[chore] exclude paddleocr directory from pre-commit hooks
@@ -4,8 +4,10 @@ repos:
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --respect-gitignore, --config=pyproject.toml]
|
||||
exclude: ^texteller/models/thrid_party/paddleocr/
|
||||
- id: ruff-format
|
||||
args: [--config=pyproject.toml]
|
||||
exclude: ^texteller/models/thrid_party/paddleocr/
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
|
||||
{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
|
||||
{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
|
||||
{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
|
||||
{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
|
||||
{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
|
||||
{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
|
||||
{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
|
||||
{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
|
||||
{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
|
||||
{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
|
||||
{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
|
||||
{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
|
||||
{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
|
||||
{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
|
||||
{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
|
||||
{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
|
||||
{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
|
||||
{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
|
||||
{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
|
||||
{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
|
||||
{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
|
||||
{"img_name": "24.png", "formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
|
||||
{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
|
||||
{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
|
||||
{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
|
||||
{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
|
||||
{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
|
||||
{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
|
||||
@@ -1,50 +0,0 @@
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
import datasets
|
||||
import json
|
||||
|
||||
DIR_URL = Path('absolute/path/to/dataset/directory')
|
||||
# e.g. DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset')
|
||||
|
||||
|
||||
class LatexFormulas(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = []
|
||||
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
features=datasets.Features({
|
||||
"image": datasets.Image(),
|
||||
"latex_formula": datasets.Value("string")
|
||||
})
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
dir_path = Path(dl_manager.download(str(DIR_URL)))
|
||||
assert dir_path.is_dir()
|
||||
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
'dir_path': dir_path,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, dir_path: Path):
|
||||
images_path = dir_path / 'images'
|
||||
formulas_path = dir_path / 'formulas.jsonl'
|
||||
|
||||
img2formula = {}
|
||||
with formulas_path.open('r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
single_json = json.loads(line)
|
||||
img2formula[single_json['img_name']] = single_json['formula']
|
||||
|
||||
for img_path in images_path.iterdir():
|
||||
if img_path.suffix not in ['.jpg', '.png']:
|
||||
continue
|
||||
yield str(img_path), {
|
||||
"image": Image.open(img_path),
|
||||
"latex_formula": img2formula[img_path.name]
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
CONFIG = {
|
||||
"seed": 42, # Random seed for reproducibility
|
||||
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
|
||||
"learning_rate": 5e-5, # Learning rate
|
||||
"num_train_epochs": 10, # Total number of training epochs
|
||||
"per_device_train_batch_size": 4, # Batch size per GPU for training
|
||||
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
|
||||
|
||||
"output_dir": "train_result", # Output directory
|
||||
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
|
||||
"report_to": ["tensorboard"], # Report logs to TensorBoard
|
||||
|
||||
"save_strategy": "steps", # Strategy to save checkpoints
|
||||
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
|
||||
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
|
||||
|
||||
"logging_strategy": "steps", # Log every certain number of steps
|
||||
"logging_steps": 500, # Number of steps between each log
|
||||
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
|
||||
|
||||
"optim": "adamw_torch", # Optimizer
|
||||
"lr_scheduler_type": "cosine", # Learning rate scheduler
|
||||
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
|
||||
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
|
||||
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
|
||||
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
|
||||
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
|
||||
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
|
||||
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
|
||||
|
||||
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
|
||||
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
|
||||
|
||||
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
|
||||
"eval_steps": 500, # If evaluation_strategy="step"
|
||||
|
||||
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
from .mix_inference import mix_inference
|
||||
@@ -1,65 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
import cv2 as cv
|
||||
from pathlib import Path
|
||||
from models.ocr_model.utils.to_katex import to_katex
|
||||
from models.ocr_model.utils.inference import inference as latex_inference
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.chdir(Path(__file__).resolve().parent)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'-img_dir',
|
||||
type=str,
|
||||
help='path to the input image',
|
||||
default='./detect_results/subimages'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-output_dir',
|
||||
type=str,
|
||||
help='path to the output dir',
|
||||
default='./rec_results'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--inference-mode',
|
||||
type=str,
|
||||
default='cpu',
|
||||
help='Inference mode, select one of cpu, cuda, or mps'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--num-beam',
|
||||
type=int,
|
||||
default=1,
|
||||
help='number of beam search for decoding'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print('Loading model and tokenizer...')
|
||||
latex_rec_model = TexTeller.from_pretrained()
|
||||
tokenizer = TexTeller.get_tokenizer()
|
||||
print('Model and tokenizer loaded.')
|
||||
|
||||
# Create the output directory if it doesn't exist
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Loop through all images in the input directory
|
||||
for filename in os.listdir(args.img_dir):
|
||||
img_path = os.path.join(args.img_dir, filename)
|
||||
img = cv.imread(img_path)
|
||||
|
||||
if img is not None:
|
||||
print(f'Inference for {filename}...')
|
||||
res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
|
||||
res = to_katex(res[0])
|
||||
|
||||
# Save the recognition result to a text file
|
||||
output_file = os.path.join(args.output_dir, os.path.splitext(filename)[0] + '.txt')
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(res)
|
||||
|
||||
print(f'Result saved to {output_file}')
|
||||
else:
|
||||
print(f"Warning: Could not read image {img_path}. Skipping...")
|
||||
@@ -1,85 +1,96 @@
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
import subprocess
|
||||
|
||||
import onnxruntime
|
||||
from pathlib import Path
|
||||
|
||||
from models.det_model.inference import PredictConfig, predict_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
|
||||
default="./models/det_model/model/infer_cfg.yml")
|
||||
parser.add_argument('--onnx_file', type=str, help="onnx model file path",
|
||||
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
|
||||
parser.add_argument("--image_dir", type=str, default='./testImgs')
|
||||
parser.add_argument("--image_file", type=str)
|
||||
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
|
||||
parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True)
|
||||
|
||||
|
||||
def get_test_images(infer_dir, infer_img):
|
||||
"""
|
||||
Get image path list in TEST mode
|
||||
"""
|
||||
assert infer_img is not None or infer_dir is not None, \
|
||||
"--image_file or --image_dir should be set"
|
||||
assert infer_img is None or os.path.isfile(infer_img), \
|
||||
"{} is not a file".format(infer_img)
|
||||
assert infer_dir is None or os.path.isdir(infer_dir), \
|
||||
"{} is not a directory".format(infer_dir)
|
||||
|
||||
# infer_img has a higher priority
|
||||
if infer_img and os.path.isfile(infer_img):
|
||||
return [infer_img]
|
||||
|
||||
images = set()
|
||||
infer_dir = os.path.abspath(infer_dir)
|
||||
assert os.path.isdir(infer_dir), \
|
||||
"infer_dir {} is not a directory".format(infer_dir)
|
||||
exts = ['jpg', 'jpeg', 'png', 'bmp']
|
||||
exts += [ext.upper() for ext in exts]
|
||||
for ext in exts:
|
||||
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
|
||||
images = list(images)
|
||||
|
||||
assert len(images) > 0, "no image found in {}".format(infer_dir)
|
||||
print("Found {} inference images in total.".format(len(images)))
|
||||
|
||||
return images
|
||||
|
||||
def download_file(url, filename):
|
||||
print(f"Downloading {filename}...")
|
||||
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
|
||||
print("Download complete.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
cur_path = os.getcwd()
|
||||
script_dirpath = Path(__file__).resolve().parent
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
FLAGS = parser.parse_args()
|
||||
|
||||
if not os.path.exists(FLAGS.infer_cfg):
|
||||
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
|
||||
download_file(infer_cfg_url, FLAGS.infer_cfg)
|
||||
|
||||
if not os.path.exists(FLAGS.onnx_file):
|
||||
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
|
||||
download_file(onnx_file_url, FLAGS.onnx_file)
|
||||
|
||||
# load image list
|
||||
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
|
||||
|
||||
if FLAGS.use_gpu:
|
||||
predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CUDAExecutionProvider'])
|
||||
else:
|
||||
predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CPUExecutionProvider'])
|
||||
# load infer config
|
||||
infer_config = PredictConfig(FLAGS.infer_cfg)
|
||||
|
||||
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
|
||||
|
||||
os.chdir(cur_path)
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
import subprocess
|
||||
|
||||
import onnxruntime
|
||||
from pathlib import Path
|
||||
|
||||
from models.det_model.inference import PredictConfig, predict_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--infer_cfg", type=str, help="infer_cfg.yml", default="./models/det_model/model/infer_cfg.yml"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--onnx_file',
|
||||
type=str,
|
||||
help="onnx model file path",
|
||||
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
|
||||
)
|
||||
parser.add_argument("--image_dir", type=str, default='./testImgs')
|
||||
parser.add_argument("--image_file", type=str)
|
||||
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
|
||||
parser.add_argument(
|
||||
'--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True
|
||||
)
|
||||
|
||||
|
||||
def get_test_images(infer_dir, infer_img):
|
||||
"""
|
||||
Get image path list in TEST mode
|
||||
"""
|
||||
assert (
|
||||
infer_img is not None or infer_dir is not None
|
||||
), "--image_file or --image_dir should be set"
|
||||
assert infer_img is None or os.path.isfile(infer_img), "{} is not a file".format(infer_img)
|
||||
assert infer_dir is None or os.path.isdir(infer_dir), "{} is not a directory".format(infer_dir)
|
||||
|
||||
# infer_img has a higher priority
|
||||
if infer_img and os.path.isfile(infer_img):
|
||||
return [infer_img]
|
||||
|
||||
images = set()
|
||||
infer_dir = os.path.abspath(infer_dir)
|
||||
assert os.path.isdir(infer_dir), "infer_dir {} is not a directory".format(infer_dir)
|
||||
exts = ['jpg', 'jpeg', 'png', 'bmp']
|
||||
exts += [ext.upper() for ext in exts]
|
||||
for ext in exts:
|
||||
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
|
||||
images = list(images)
|
||||
|
||||
assert len(images) > 0, "no image found in {}".format(infer_dir)
|
||||
print("Found {} inference images in total.".format(len(images)))
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def download_file(url, filename):
|
||||
print(f"Downloading {filename}...")
|
||||
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
|
||||
print("Download complete.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cur_path = os.getcwd()
|
||||
script_dirpath = Path(__file__).resolve().parent
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
FLAGS = parser.parse_args()
|
||||
|
||||
if not os.path.exists(FLAGS.infer_cfg):
|
||||
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
|
||||
download_file(infer_cfg_url, FLAGS.infer_cfg)
|
||||
|
||||
if not os.path.exists(FLAGS.onnx_file):
|
||||
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
|
||||
download_file(onnx_file_url, FLAGS.onnx_file)
|
||||
|
||||
# load image list
|
||||
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
|
||||
|
||||
if FLAGS.use_gpu:
|
||||
predictor = onnxruntime.InferenceSession(
|
||||
FLAGS.onnx_file, providers=['CUDAExecutionProvider']
|
||||
)
|
||||
else:
|
||||
predictor = onnxruntime.InferenceSession(
|
||||
FLAGS.onnx_file, providers=['CPUExecutionProvider']
|
||||
)
|
||||
# load infer config
|
||||
infer_config = PredictConfig(FLAGS.infer_cfg)
|
||||
|
||||
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
|
||||
|
||||
os.chdir(cur_path)
|
||||
@@ -18,32 +18,20 @@ from models.det_model.inference import PredictConfig
|
||||
if __name__ == '__main__':
|
||||
os.chdir(Path(__file__).resolve().parent)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-img', type=str, required=True, help='path to the input image')
|
||||
parser.add_argument(
|
||||
'-img',
|
||||
type=str,
|
||||
required=True,
|
||||
help='path to the input image'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--inference-mode',
|
||||
'--inference-mode',
|
||||
type=str,
|
||||
default='cpu',
|
||||
help='Inference mode, select one of cpu, cuda, or mps'
|
||||
help='Inference mode, select one of cpu, cuda, or mps',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--num-beam',
|
||||
type=int,
|
||||
default=1,
|
||||
help='number of beam search for decoding'
|
||||
'--num-beam', type=int, default=1, help='number of beam search for decoding'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-mix',
|
||||
action='store_true',
|
||||
help='use mix mode'
|
||||
)
|
||||
|
||||
parser.add_argument('-mix', action='store_true', help='use mix mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# You can use your own checkpoint and tokenizer path.
|
||||
print('Loading model and tokenizer...')
|
||||
latex_rec_model = TexTeller.from_pretrained()
|
||||
@@ -63,8 +51,8 @@ if __name__ == '__main__':
|
||||
|
||||
use_gpu = args.inference_mode == 'cuda'
|
||||
SIZE_LIMIT = 20 * 1024 * 1024
|
||||
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
|
||||
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
|
||||
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
|
||||
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
|
||||
# The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
|
||||
det_use_gpu = False
|
||||
rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
|
||||
@@ -78,8 +66,16 @@ if __name__ == '__main__':
|
||||
detector = predict_det.TextDetector(paddleocr_args)
|
||||
paddleocr_args.use_gpu = rec_use_gpu
|
||||
recognizer = predict_rec.TextRecognizer(paddleocr_args)
|
||||
|
||||
|
||||
lang_ocr_models = [detector, recognizer]
|
||||
latex_rec_models = [latex_rec_model, tokenizer]
|
||||
res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
|
||||
res = mix_inference(
|
||||
img_path,
|
||||
infer_config,
|
||||
latex_det_model,
|
||||
lang_ocr_models,
|
||||
latex_rec_models,
|
||||
args.inference_mode,
|
||||
args.num_beam,
|
||||
)
|
||||
print(res)
|
||||
BIN
texteller/models/__pycache__/globals.cpython-310.pyc
Normal file
@@ -9,7 +9,7 @@ class Point:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = int(x)
|
||||
self.y = int(y)
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Point(x={self.x}, y={self.y})"
|
||||
|
||||
@@ -28,30 +28,28 @@ class Bbox:
|
||||
@property
|
||||
def ul_point(self) -> Point:
|
||||
return self.p
|
||||
|
||||
|
||||
@property
|
||||
def ur_point(self) -> Point:
|
||||
return Point(self.p.x + self.w, self.p.y)
|
||||
|
||||
|
||||
@property
|
||||
def ll_point(self) -> Point:
|
||||
return Point(self.p.x, self.p.y + self.h)
|
||||
|
||||
|
||||
@property
|
||||
def lr_point(self) -> Point:
|
||||
return Point(self.p.x + self.w, self.p.y + self.h)
|
||||
|
||||
|
||||
|
||||
def same_row(self, other) -> bool:
|
||||
if (
|
||||
(self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y)
|
||||
or (self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y)
|
||||
if (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y) or (
|
||||
self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y
|
||||
):
|
||||
return True
|
||||
if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
|
||||
return False
|
||||
return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD
|
||||
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
'''
|
||||
from top to bottom, from left to right
|
||||
@@ -60,7 +58,7 @@ class Bbox:
|
||||
return self.p.y < other.p.y
|
||||
else:
|
||||
return self.p.x < other.p.x
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"
|
||||
|
||||
@@ -76,16 +74,16 @@ def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"
|
||||
top = bbox.p.y
|
||||
right = bbox.p.x + bbox.w
|
||||
bottom = bbox.p.y + bbox.h
|
||||
|
||||
|
||||
# Draw the rectangle on the image
|
||||
drawer.rectangle([left, top, right, bottom], outline="green", width=1)
|
||||
|
||||
|
||||
# Optionally, add text label if it exists
|
||||
if bbox.label:
|
||||
drawer.text((left, top), bbox.label, fill="blue")
|
||||
|
||||
|
||||
if bbox.content:
|
||||
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
|
||||
|
||||
# Save the image with drawn rectangles
|
||||
img.save(log_dir / name)
|
||||
img.save(log_dir / name)
|
||||
BIN
texteller/models/det_model/__pycache__/Bbox.cpython-310.pyc
Normal file
BIN
texteller/models/det_model/__pycache__/inference.cpython-310.pyc
Normal file
@@ -12,10 +12,28 @@ from .Bbox import Bbox
|
||||
|
||||
# Global dictionary
|
||||
SUPPORT_MODELS = {
|
||||
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
|
||||
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
|
||||
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
|
||||
'DETR'
|
||||
'YOLO',
|
||||
'PPYOLOE',
|
||||
'RCNN',
|
||||
'SSD',
|
||||
'Face',
|
||||
'FCOS',
|
||||
'SOLOv2',
|
||||
'TTFNet',
|
||||
'S2ANet',
|
||||
'JDE',
|
||||
'FairMOT',
|
||||
'DeepSORT',
|
||||
'GFL',
|
||||
'PicoDet',
|
||||
'CenterNet',
|
||||
'TOOD',
|
||||
'RetinaNet',
|
||||
'StrongBaseline',
|
||||
'STGCN',
|
||||
'YOLOX',
|
||||
'HRNet',
|
||||
'DETR',
|
||||
}
|
||||
|
||||
|
||||
@@ -42,12 +60,12 @@ class PredictConfig(object):
|
||||
self.fpn_stride = yml_conf.get("fpn_stride", None)
|
||||
|
||||
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
|
||||
self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
|
||||
self.colors = {
|
||||
label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)
|
||||
}
|
||||
|
||||
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
|
||||
print(
|
||||
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
|
||||
)
|
||||
print('The RCNN export model is used for ONNX and it only supports batch_size = 1')
|
||||
self.print_config()
|
||||
|
||||
def check_model(self, yml_conf):
|
||||
@@ -58,8 +76,7 @@ class PredictConfig(object):
|
||||
for support_model in SUPPORT_MODELS:
|
||||
if support_model in yml_conf['arch']:
|
||||
return True
|
||||
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
|
||||
'arch'], SUPPORT_MODELS))
|
||||
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))
|
||||
|
||||
def print_config(self):
|
||||
print('----------- Model Configuration -----------')
|
||||
@@ -77,8 +94,15 @@ def draw_bbox(image, outputs, infer_config):
|
||||
label = infer_config.label_list[int(cls_id)]
|
||||
color = infer_config.colors[label]
|
||||
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
|
||||
cv2.putText(image, "{}: {:.2f}".format(label, score),
|
||||
(int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||
cv2.putText(
|
||||
image,
|
||||
"{}: {:.2f}".format(label, score),
|
||||
(int(xmin), int(ymin - 5)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
color,
|
||||
2,
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
@@ -104,7 +128,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
|
||||
|
||||
inputs = transforms(img_path)
|
||||
inputs_name = [var.name for var in predictor.get_inputs()]
|
||||
inputs = {k: inputs[k][None, ] for k in inputs_name}
|
||||
inputs = {k: inputs[k][None,] for k in inputs_name}
|
||||
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
@@ -119,7 +143,9 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
|
||||
else:
|
||||
total_time += inference_time
|
||||
num_images += 1
|
||||
print(f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds")
|
||||
print(
|
||||
f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds"
|
||||
)
|
||||
|
||||
print("ONNXRuntime predict: ")
|
||||
if infer_config.arch in ["HRNet"]:
|
||||
@@ -128,8 +154,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
|
||||
bboxes = np.array(outputs[0])
|
||||
for bbox in bboxes:
|
||||
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
|
||||
print(f"{int(bbox[0])} {bbox[1]} "
|
||||
f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
|
||||
print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
|
||||
|
||||
# Save the subimages (crop from the original image)
|
||||
subimg_counter = 1
|
||||
@@ -137,7 +162,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
|
||||
cls_id, score, xmin, ymin, xmax, ymax = output
|
||||
if score > infer_config.draw_threshold:
|
||||
label = infer_config.label_list[int(cls_id)]
|
||||
subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)]
|
||||
subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)]
|
||||
if len(subimg) == 0:
|
||||
continue
|
||||
|
||||
@@ -151,8 +176,14 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
|
||||
for output in np.array(outputs[0]):
|
||||
cls_id, score, xmin, ymin, xmax, ymax = output
|
||||
if score > infer_config.draw_threshold:
|
||||
cv2.rectangle(img_with_mask, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 255), -1) # 盖白
|
||||
|
||||
cv2.rectangle(
|
||||
img_with_mask,
|
||||
(int(xmin), int(ymin)),
|
||||
(int(xmax), int(ymax)),
|
||||
(255, 255, 255),
|
||||
-1,
|
||||
) # 盖白
|
||||
|
||||
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
|
||||
|
||||
output_dir = imgsave_dir
|
||||
@@ -178,7 +209,7 @@ def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
|
||||
transforms = Compose(infer_config.preprocess_infos)
|
||||
inputs = transforms(img_path)
|
||||
inputs_name = [var.name for var in predictor.get_inputs()]
|
||||
inputs = {k: inputs[k][None, ] for k in inputs_name}
|
||||
inputs = {k: inputs[k][None,] for k in inputs_name}
|
||||
|
||||
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
|
||||
res = []
|
||||
@@ -15,10 +15,8 @@ def decode_image(img_path):
|
||||
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
|
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||
img_info = {
|
||||
"im_shape": np.array(
|
||||
im.shape[:2], dtype=np.float32),
|
||||
"scale_factor": np.array(
|
||||
[1., 1.], dtype=np.float32)
|
||||
"im_shape": np.array(im.shape[:2], dtype=np.float32),
|
||||
"scale_factor": np.array([1.0, 1.0], dtype=np.float32),
|
||||
}
|
||||
return im, img_info
|
||||
|
||||
@@ -51,16 +49,9 @@ class Resize(object):
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
im_channel = im.shape[2]
|
||||
im_scale_y, im_scale_x = self.generate_scale(im)
|
||||
im = cv2.resize(
|
||||
im,
|
||||
None,
|
||||
None,
|
||||
fx=im_scale_x,
|
||||
fy=im_scale_y,
|
||||
interpolation=self.interp)
|
||||
im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
|
||||
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
|
||||
im_info['scale_factor'] = np.array(
|
||||
[im_scale_y, im_scale_x]).astype('float32')
|
||||
im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
|
||||
return im, im_info
|
||||
|
||||
def generate_scale(self, im):
|
||||
@@ -134,7 +125,9 @@ class Permute(object):
|
||||
channel_first (bool): whether convert HWC to CHW
|
||||
"""
|
||||
|
||||
def __init__(self, ):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super(Permute, self).__init__()
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
@@ -151,7 +144,7 @@ class Permute(object):
|
||||
|
||||
|
||||
class PadStride(object):
|
||||
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
|
||||
"""padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
|
||||
Args:
|
||||
stride (bool): model with FPN need image shape % stride == 0
|
||||
"""
|
||||
@@ -198,18 +191,16 @@ class LetterBoxResize(object):
|
||||
ratio_h = float(height) / shape[0]
|
||||
ratio_w = float(width) / shape[1]
|
||||
ratio = min(ratio_h, ratio_w)
|
||||
new_shape = (round(shape[1] * ratio),
|
||||
round(shape[0] * ratio)) # [width, height]
|
||||
new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # [width, height]
|
||||
padw = (width - new_shape[0]) / 2
|
||||
padh = (height - new_shape[1]) / 2
|
||||
top, bottom = round(padh - 0.1), round(padh + 0.1)
|
||||
left, right = round(padw - 0.1), round(padw + 0.1)
|
||||
|
||||
img = cv2.resize(
|
||||
img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
|
||||
img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
|
||||
img = cv2.copyMakeBorder(
|
||||
img, top, bottom, left, right, cv2.BORDER_CONSTANT,
|
||||
value=color) # padded rectangular
|
||||
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
||||
) # padded rectangular
|
||||
return img, ratio, padw, padh
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
@@ -302,12 +293,7 @@ def _get_3rd_point(a, b):
|
||||
return third_pt
|
||||
|
||||
|
||||
def get_affine_transform(center,
|
||||
input_size,
|
||||
rot,
|
||||
output_size,
|
||||
shift=(0., 0.),
|
||||
inv=False):
|
||||
def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False):
|
||||
"""Get the affine transform matrix, given the center/scale/rot/output_size.
|
||||
|
||||
Args:
|
||||
@@ -337,8 +323,8 @@ def get_affine_transform(center,
|
||||
dst_h = output_size[1]
|
||||
|
||||
rot_rad = np.pi * rot / 180
|
||||
src_dir = rotate_point([0., src_w * -0.5], rot_rad)
|
||||
dst_dir = np.array([0., dst_w * -0.5])
|
||||
src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
|
||||
dst_dir = np.array([0.0, dst_w * -0.5])
|
||||
|
||||
src = np.zeros((3, 2), dtype=np.float32)
|
||||
src[0, :] = center + scale_tmp * shift
|
||||
@@ -359,16 +345,9 @@ def get_affine_transform(center,
|
||||
|
||||
|
||||
class WarpAffine(object):
|
||||
"""Warp affine the image
|
||||
"""
|
||||
"""Warp affine the image"""
|
||||
|
||||
def __init__(self,
|
||||
keep_res=False,
|
||||
pad=31,
|
||||
input_h=512,
|
||||
input_w=512,
|
||||
scale=0.4,
|
||||
shift=0.1):
|
||||
def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1):
|
||||
self.keep_res = keep_res
|
||||
self.pad = pad
|
||||
self.input_h = input_h
|
||||
@@ -398,12 +377,11 @@ class WarpAffine(object):
|
||||
else:
|
||||
s = max(h, w) * 1.0
|
||||
input_h, input_w = self.input_h, self.input_w
|
||||
c = np.array([w / 2., h / 2.], dtype=np.float32)
|
||||
c = np.array([w / 2.0, h / 2.0], dtype=np.float32)
|
||||
|
||||
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
|
||||
img = cv2.resize(img, (w, h))
|
||||
inp = cv2.warpAffine(
|
||||
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
|
||||
inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
|
||||
return inp, im_info
|
||||
|
||||
|
||||
@@ -432,13 +410,17 @@ def get_warp_matrix(theta, size_input, size_dst, size_target):
|
||||
matrix[0, 0] = np.cos(theta) * scale_x
|
||||
matrix[0, 1] = -np.sin(theta) * scale_x
|
||||
matrix[0, 2] = scale_x * (
|
||||
-0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
|
||||
np.sin(theta) + 0.5 * size_target[0])
|
||||
-0.5 * size_input[0] * np.cos(theta)
|
||||
+ 0.5 * size_input[1] * np.sin(theta)
|
||||
+ 0.5 * size_target[0]
|
||||
)
|
||||
matrix[1, 0] = np.sin(theta) * scale_y
|
||||
matrix[1, 1] = np.cos(theta) * scale_y
|
||||
matrix[1, 2] = scale_y * (
|
||||
-0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
|
||||
np.cos(theta) + 0.5 * size_target[1])
|
||||
-0.5 * size_input[0] * np.sin(theta)
|
||||
- 0.5 * size_input[1] * np.cos(theta)
|
||||
+ 0.5 * size_target[1]
|
||||
)
|
||||
return matrix
|
||||
|
||||
|
||||
@@ -462,22 +444,26 @@ class TopDownEvalAffine(object):
|
||||
def __call__(self, image, im_info):
|
||||
rot = 0
|
||||
imshape = im_info['im_shape'][::-1]
|
||||
center = im_info['center'] if 'center' in im_info else imshape / 2.
|
||||
center = im_info['center'] if 'center' in im_info else imshape / 2.0
|
||||
scale = im_info['scale'] if 'scale' in im_info else imshape
|
||||
if self.use_udp:
|
||||
trans = get_warp_matrix(
|
||||
rot, center * 2.0,
|
||||
[self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
|
||||
rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale
|
||||
)
|
||||
image = cv2.warpAffine(
|
||||
image,
|
||||
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
|
||||
flags=cv2.INTER_LINEAR)
|
||||
trans,
|
||||
(int(self.trainsize[0]), int(self.trainsize[1])),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
)
|
||||
else:
|
||||
trans = get_affine_transform(center, scale, rot, self.trainsize)
|
||||
image = cv2.warpAffine(
|
||||
image,
|
||||
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
|
||||
flags=cv2.INTER_LINEAR)
|
||||
trans,
|
||||
(int(self.trainsize[0]), int(self.trainsize[1])),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
return image, im_info
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Formula image(grayscale) mean and variance
|
||||
IMAGE_MEAN = 0.9545467
|
||||
IMAGE_STD = 0.15394445
|
||||
IMAGE_STD = 0.15394445
|
||||
|
||||
# Vocabulary size for TexTeller
|
||||
VOCAB_SIZE = 15000
|
||||
@@ -20,4 +20,4 @@ MIN_RESIZE_RATIO = 0.75
|
||||
|
||||
# Minimum height and width for input image for TexTeller
|
||||
MIN_HEIGHT = 12
|
||||
MIN_WIDTH = 30
|
||||
MIN_WIDTH = 30
|
||||
@@ -1,30 +1,24 @@
|
||||
from pathlib import Path
|
||||
|
||||
from ...globals import (
|
||||
VOCAB_SIZE,
|
||||
FIXED_IMG_SIZE,
|
||||
IMG_CHANNELS,
|
||||
MAX_TOKEN_SIZE
|
||||
)
|
||||
from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE
|
||||
|
||||
from transformers import (
|
||||
RobertaTokenizerFast,
|
||||
VisionEncoderDecoderModel,
|
||||
VisionEncoderDecoderConfig
|
||||
)
|
||||
from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
|
||||
|
||||
|
||||
class TexTeller(VisionEncoderDecoderModel):
|
||||
REPO_NAME = 'OleehyO/TexTeller'
|
||||
|
||||
def __init__(self):
|
||||
config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json")
|
||||
config.encoder.image_size = FIXED_IMG_SIZE
|
||||
config.encoder.num_channels = IMG_CHANNELS
|
||||
config.decoder.vocab_size = VOCAB_SIZE
|
||||
config = VisionEncoderDecoderConfig.from_pretrained(
|
||||
Path(__file__).resolve().parent / "config.json"
|
||||
)
|
||||
config.encoder.image_size = FIXED_IMG_SIZE
|
||||
config.encoder.num_channels = IMG_CHANNELS
|
||||
config.decoder.vocab_size = VOCAB_SIZE
|
||||
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
|
||||
|
||||
super().__init__(config=config)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
|
||||
if model_path is None or model_path == 'default':
|
||||
@@ -32,8 +26,12 @@ class TexTeller(VisionEncoderDecoderModel):
|
||||
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
|
||||
else:
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
|
||||
use_gpu = True if onnx_provider == 'cuda' else False
|
||||
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
|
||||
return ORTModelForVision2Seq.from_pretrained(
|
||||
cls.REPO_NAME,
|
||||
provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
|
||||
)
|
||||
model_path = Path(model_path).resolve()
|
||||
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
|
||||
|
||||
BIN
texteller/models/ocr_model/train/augraphy_cache/image_0.png
Normal file
|
After Width: | Height: | Size: 14 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_1.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_10.png
Normal file
|
After Width: | Height: | Size: 4.6 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_11.png
Normal file
|
After Width: | Height: | Size: 8.5 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_12.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_13.png
Normal file
|
After Width: | Height: | Size: 3.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_14.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_15.png
Normal file
|
After Width: | Height: | Size: 7.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_16.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_17.png
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_18.png
Normal file
|
After Width: | Height: | Size: 5.6 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_19.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_2.png
Normal file
|
After Width: | Height: | Size: 6.4 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_20.png
Normal file
|
After Width: | Height: | Size: 6.4 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_21.png
Normal file
|
After Width: | Height: | Size: 16 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_22.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_23.png
Normal file
|
After Width: | Height: | Size: 5.3 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_24.png
Normal file
|
After Width: | Height: | Size: 10 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_25.png
Normal file
|
After Width: | Height: | Size: 30 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_26.png
Normal file
|
After Width: | Height: | Size: 9.8 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_27.png
Normal file
|
After Width: | Height: | Size: 8.7 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_28.png
Normal file
|
After Width: | Height: | Size: 15 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_29.png
Normal file
|
After Width: | Height: | Size: 7.8 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_3.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_4.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_5.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_6.png
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_7.png
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_8.png
Normal file
|
After Width: | Height: | Size: 6.0 KiB |
BIN
texteller/models/ocr_model/train/augraphy_cache/image_9.png
Normal file
|
After Width: | Height: | Size: 18 KiB |
|
Before Width: | Height: | Size: 3.1 KiB After Width: | Height: | Size: 3.1 KiB |
|
Before Width: | Height: | Size: 8.7 KiB After Width: | Height: | Size: 8.7 KiB |
|
Before Width: | Height: | Size: 6.8 KiB After Width: | Height: | Size: 6.8 KiB |
|
Before Width: | Height: | Size: 4.1 KiB After Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 5.2 KiB After Width: | Height: | Size: 5.2 KiB |
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 12 KiB |
|
Before Width: | Height: | Size: 2.8 KiB After Width: | Height: | Size: 2.8 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 2.6 KiB After Width: | Height: | Size: 2.6 KiB |
|
Before Width: | Height: | Size: 3.1 KiB After Width: | Height: | Size: 3.1 KiB |
|
Before Width: | Height: | Size: 2.7 KiB After Width: | Height: | Size: 2.7 KiB |
|
Before Width: | Height: | Size: 3.9 KiB After Width: | Height: | Size: 3.9 KiB |
|
Before Width: | Height: | Size: 3.9 KiB After Width: | Height: | Size: 3.9 KiB |
|
Before Width: | Height: | Size: 2.9 KiB After Width: | Height: | Size: 2.9 KiB |
|
Before Width: | Height: | Size: 3.7 KiB After Width: | Height: | Size: 3.7 KiB |
|
Before Width: | Height: | Size: 3.5 KiB After Width: | Height: | Size: 3.5 KiB |
|
Before Width: | Height: | Size: 3.1 KiB After Width: | Height: | Size: 3.1 KiB |
|
Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 2.5 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 3.1 KiB After Width: | Height: | Size: 3.1 KiB |
|
Before Width: | Height: | Size: 2.9 KiB After Width: | Height: | Size: 2.9 KiB |
|
Before Width: | Height: | Size: 5.3 KiB After Width: | Height: | Size: 5.3 KiB |
|
Before Width: | Height: | Size: 4.1 KiB After Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 3.9 KiB After Width: | Height: | Size: 3.9 KiB |
|
Before Width: | Height: | Size: 4.9 KiB After Width: | Height: | Size: 4.9 KiB |
|
Before Width: | Height: | Size: 2.9 KiB After Width: | Height: | Size: 2.9 KiB |
|
Before Width: | Height: | Size: 1.8 KiB After Width: | Height: | Size: 1.8 KiB |
|
Before Width: | Height: | Size: 3.2 KiB After Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 5.7 KiB After Width: | Height: | Size: 5.7 KiB |
|
Before Width: | Height: | Size: 11 KiB After Width: | Height: | Size: 11 KiB |
|
Before Width: | Height: | Size: 4.8 KiB After Width: | Height: | Size: 4.8 KiB |
|
Before Width: | Height: | Size: 4.5 KiB After Width: | Height: | Size: 4.5 KiB |
|
Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 2.5 KiB |
|
Before Width: | Height: | Size: 5.2 KiB After Width: | Height: | Size: 5.2 KiB |
@@ -0,0 +1,35 @@
|
||||
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
|
||||
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
|
||||
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
|
||||
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
|
||||
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
|
||||
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
|
||||
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
|
||||
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
|
||||
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
|
||||
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
|
||||
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
|
||||
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
|
||||
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
|
||||
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
|
||||
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
|
||||
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
|
||||
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
|
||||
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
|
||||
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
|
||||
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
|
||||
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
|
||||
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
|
||||
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
|
||||
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
|
||||
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
|
||||
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
|
||||
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
|
||||
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
|
||||
@@ -5,18 +5,24 @@ from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
GenerationConfig
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
GenerationConfig,
|
||||
)
|
||||
|
||||
from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
|
||||
from ..utils.functional import (
|
||||
tokenize_fn,
|
||||
collate_fn,
|
||||
img_train_transform,
|
||||
img_inf_transform,
|
||||
filter_fn,
|
||||
)
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
|
||||
|
||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
||||
@@ -24,11 +30,9 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
|
||||
trainer = Trainer(
|
||||
model,
|
||||
training_args,
|
||||
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collate_fn_with_tokenizer,
|
||||
)
|
||||
|
||||
@@ -52,43 +56,44 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
trainer = Seq2SeqTrainer(
|
||||
model,
|
||||
seq2seq_config,
|
||||
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collate_fn,
|
||||
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
|
||||
compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
|
||||
)
|
||||
|
||||
eval_res = trainer.evaluate()
|
||||
print(eval_res)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
script_dirpath = Path(__file__).resolve().parent
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
|
||||
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
|
||||
# dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
|
||||
dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
|
||||
dataset = dataset.filter(
|
||||
lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
|
||||
)
|
||||
dataset = dataset.shuffle(seed=42)
|
||||
dataset = dataset.flatten_indices()
|
||||
|
||||
tokenizer = TexTeller.get_tokenizer()
|
||||
# If you want use your own tokenizer, please modify the path to your tokenizer
|
||||
#+tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
|
||||
# +tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
|
||||
filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
|
||||
dataset = dataset.filter(
|
||||
filter_fn_with_tokenizer,
|
||||
num_proc=8
|
||||
)
|
||||
dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)
|
||||
|
||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
|
||||
tokenized_dataset = dataset.map(
|
||||
map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
|
||||
)
|
||||
|
||||
# Split dataset into train and eval, ratio 9:1
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
train_dataset = train_dataset.with_transform(img_train_transform)
|
||||
eval_dataset = eval_dataset.with_transform(img_inf_transform)
|
||||
eval_dataset = eval_dataset.with_transform(img_inf_transform)
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
|
||||
# Train from scratch
|
||||
@@ -96,14 +101,14 @@ if __name__ == '__main__':
|
||||
# or train from TexTeller pre-trained model: model = TexTeller.from_pretrained()
|
||||
|
||||
# If you want to train from pre-trained model, please modify the path to your pre-trained checkpoint
|
||||
#+e.g.
|
||||
#+model = TexTeller.from_pretrained(
|
||||
#+ '/path/to/your/model_checkpoint'
|
||||
#+)
|
||||
# +e.g.
|
||||
# +model = TexTeller.from_pretrained(
|
||||
# + '/path/to/your/model_checkpoint'
|
||||
# +)
|
||||
|
||||
enable_train = True
|
||||
enable_evaluate = False
|
||||
if enable_train:
|
||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||
if enable_evaluate and len(eval_dataset) > 0:
|
||||
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
|
||||
31
texteller/models/ocr_model/train/training_args.py
Normal file
@@ -0,0 +1,31 @@
|
||||
CONFIG = {
|
||||
"seed": 42, # Random seed for reproducibility
|
||||
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
|
||||
"learning_rate": 5e-5, # Learning rate
|
||||
"num_train_epochs": 10, # Total number of training epochs
|
||||
"per_device_train_batch_size": 4, # Batch size per GPU for training
|
||||
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
|
||||
"output_dir": "train_result", # Output directory
|
||||
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
|
||||
"report_to": ["tensorboard"], # Report logs to TensorBoard
|
||||
"save_strategy": "steps", # Strategy to save checkpoints
|
||||
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
|
||||
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
|
||||
"logging_strategy": "steps", # Log every certain number of steps
|
||||
"logging_steps": 500, # Number of steps between each log
|
||||
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
|
||||
"optim": "adamw_torch", # Optimizer
|
||||
"lr_scheduler_type": "cosine", # Learning rate scheduler
|
||||
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
|
||||
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
|
||||
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
|
||||
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
|
||||
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
|
||||
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
|
||||
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
|
||||
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
|
||||
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
|
||||
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
|
||||
"eval_steps": 500, # If evaluation_strategy="step"
|
||||
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
|
||||
}
|
||||
@@ -26,7 +26,7 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
|
||||
pixel_values = [dic.pop('pixel_values') for dic in samples]
|
||||
|
||||
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
|
||||
batch = clm_collator(samples)
|
||||
batch['pixel_values'] = pixel_values
|
||||
batch['decoder_input_ids'] = batch.pop('input_ids')
|
||||
@@ -54,6 +54,7 @@ def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
|
||||
def filter_fn(sample, tokenizer=None) -> bool:
|
||||
return (
|
||||
sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
|
||||
sample['image'].height > MIN_HEIGHT
|
||||
and sample['image'].width > MIN_WIDTH
|
||||
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
|
||||
)
|
||||
@@ -12,7 +12,7 @@ def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
|
||||
continue
|
||||
if image.dtype == np.uint16:
|
||||
print(f'Converting {path} to 8-bit, image may be lossy.')
|
||||
image = cv2.convertScaleAbs(image, alpha=(255.0/65535.0))
|
||||
image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
|
||||
|
||||
channels = 1 if len(image.shape) == 2 else image.shape[2]
|
||||
if channels == 4:
|
||||