[chore] exclude paddleocr directory from pre-commit hooks
@@ -4,8 +4,10 @@ repos:
    hooks:
      - id: ruff
        args: [--fix, --respect-gitignore, --config=pyproject.toml]
+       exclude: ^texteller/models/thrid_party/paddleocr/
      - id: ruff-format
        args: [--config=pyproject.toml]
+       exclude: ^texteller/models/thrid_party/paddleocr/

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
@@ -1,35 +0,0 @@
{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"img_name": "24.png", "formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
@@ -1,50 +0,0 @@
from PIL import Image
from pathlib import Path
import datasets
import json

DIR_URL = Path('absolute/path/to/dataset/directory')
# e.g. DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset')


class LatexFormulas(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = []

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({
                "image": datasets.Image(),
                "latex_formula": datasets.Value("string")
            })
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        dir_path = Path(dl_manager.download(str(DIR_URL)))
        assert dir_path.is_dir()

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    'dir_path': dir_path,
                }
            )
        ]

    def _generate_examples(self, dir_path: Path):
        images_path = dir_path / 'images'
        formulas_path = dir_path / 'formulas.jsonl'

        img2formula = {}
        with formulas_path.open('r', encoding='utf-8') as f:
            for line in f:
                single_json = json.loads(line)
                img2formula[single_json['img_name']] = single_json['formula']

        for img_path in images_path.iterdir():
            if img_path.suffix not in ['.jpg', '.png']:
                continue
            yield str(img_path), {
                "image": Image.open(img_path),
                "latex_formula": img2formula[img_path.name]
            }
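For context, a script-based builder like the one removed above is normally consumed through datasets.load_dataset; a minimal sketch, assuming the builder script is saved as latex_formulas.py (the path below is illustrative, not a file in this repo):

import datasets

# Hypothetical location of the builder script shown above.
ds = datasets.load_dataset("path/to/latex_formulas.py", split="train")
print(ds[0]["latex_formula"])  # LaTeX string paired with the PIL image in ds[0]["image"]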
@@ -1,38 +0,0 @@
CONFIG = {
    "seed": 42,  # Random seed for reproducibility
    "use_cpu": False,  # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
    "learning_rate": 5e-5,  # Learning rate
    "num_train_epochs": 10,  # Total number of training epochs
    "per_device_train_batch_size": 4,  # Batch size per GPU for training
    "per_device_eval_batch_size": 8,  # Batch size per GPU for evaluation

    "output_dir": "train_result",  # Output directory
    "overwrite_output_dir": False,  # If the output directory exists, do not delete its content
    "report_to": ["tensorboard"],  # Report logs to TensorBoard

    "save_strategy": "steps",  # Strategy to save checkpoints
    "save_steps": 500,  # Interval of steps between checkpoints; can be an int or a float in (0, 1), where a float means the ratio of total training steps (e.g. 1.0 / 2000)
    "save_total_limit": 5,  # Maximum number of checkpoints to keep; the oldest are deleted when this number is exceeded

    "logging_strategy": "steps",  # Log every certain number of steps
    "logging_steps": 500,  # Number of steps between each log
    "logging_nan_inf_filter": False,  # Record logs even when loss is nan or inf

    "optim": "adamw_torch",  # Optimizer
    "lr_scheduler_type": "cosine",  # Learning rate scheduler
    "warmup_ratio": 0.1,  # Ratio of warmup steps to total training steps (e.g. for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
    "max_grad_norm": 1.0,  # Gradient clipping: ensure the norm of the gradients does not exceed 1.0 (default 1.0)
    "fp16": False,  # Whether to use 16-bit floating point for training (generally not recommended, as the loss can easily explode)
    "bf16": False,  # Whether to use Brain Floating Point (bfloat16) for training (recommended if the hardware supports it)
    "gradient_accumulation_steps": 1,  # Gradient accumulation steps; use this to emulate a large batch size when the per-device batch size cannot be large
    "jit_mode_eval": False,  # Whether to use PyTorch JIT trace during eval (can speed up the model, but the model must be static, otherwise it will throw errors)
    "torch_compile": False,  # Whether to use torch.compile to compile the model (for better training and inference performance)

    "dataloader_pin_memory": True,  # Can speed up data transfer between CPU and GPU
    "dataloader_num_workers": 1,  # Default is not to use multiprocessing for data loading; usually set to 4 * number of GPUs used

    "evaluation_strategy": "steps",  # Evaluation strategy, can be "steps" or "epoch"
    "eval_steps": 500,  # Interval of steps between evaluations when evaluation_strategy="steps"

    "remove_unused_columns": False,  # Don't change this unless you really know what you are doing.
}
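For reference, every key in this CONFIG corresponds to a Hugging Face TrainingArguments field, so a dictionary like the one above is typically unpacked straight into the trainer; a minimal sketch assuming a recent transformers version (not necessarily the exact trainer code used in this repo):

from transformers import Trainer, TrainingArguments

# CONFIG is the dictionary shown above; each key maps onto a TrainingArguments argument.
training_args = TrainingArguments(**CONFIG)
# trainer = Trainer(model=model, args=training_args, train_dataset=..., eval_dataset=...)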
@@ -1 +0,0 @@
from .mix_inference import mix_inference
@@ -1,65 +0,0 @@
import os
import argparse
import cv2 as cv
from pathlib import Path
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller


if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img_dir',
        type=str,
        help='path to the input image',
        default='./detect_results/subimages'
    )
    parser.add_argument(
        '-output_dir',
        type=str,
        help='path to the output dir',
        default='./rec_results'
    )
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='Inference mode, select one of cpu, cuda, or mps'
    )
    parser.add_argument(
        '--num-beam',
        type=int,
        default=1,
        help='number of beam search for decoding'
    )

    args = parser.parse_args()

    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    # Create the output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Loop through all images in the input directory
    for filename in os.listdir(args.img_dir):
        img_path = os.path.join(args.img_dir, filename)
        img = cv.imread(img_path)

        if img is not None:
            print(f'Inference for {filename}...')
            res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
            res = to_katex(res[0])

            # Save the recognition result to a text file
            output_file = os.path.join(args.output_dir, os.path.splitext(filename)[0] + '.txt')
            with open(output_file, 'w') as f:
                f.write(res)

            print(f'Result saved to {output_file}')
        else:
            print(f"Warning: Could not read image {img_path}. Skipping...")
@@ -1,85 +1,96 @@
import os
import argparse
import glob
import subprocess

import onnxruntime
from pathlib import Path

from models.det_model.inference import PredictConfig, predict_image


parser = argparse.ArgumentParser(description=__doc__)
-parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
-                    default="./models/det_model/model/infer_cfg.yml")
-parser.add_argument('--onnx_file', type=str, help="onnx model file path",
-                    default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+parser.add_argument(
+    "--infer_cfg", type=str, help="infer_cfg.yml", default="./models/det_model/model/infer_cfg.yml"
+)
+parser.add_argument(
+    '--onnx_file',
+    type=str,
+    help="onnx model file path",
+    default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
+)
parser.add_argument("--image_dir", type=str, default='./testImgs')
parser.add_argument("--image_file", type=str)
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
-parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True)
+parser.add_argument(
+    '--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True
+)


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
-    assert infer_img is not None or infer_dir is not None, \
-        "--image_file or --image_dir should be set"
-    assert infer_img is None or os.path.isfile(infer_img), \
-        "{} is not a file".format(infer_img)
-    assert infer_dir is None or os.path.isdir(infer_dir), \
-        "{} is not a directory".format(infer_dir)
+    assert (
+        infer_img is not None or infer_dir is not None
+    ), "--image_file or --image_dir should be set"
+    assert infer_img is None or os.path.isfile(infer_img), "{} is not a file".format(infer_img)
+    assert infer_dir is None or os.path.isdir(infer_dir), "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
-    assert os.path.isdir(infer_dir), \
-        "infer_dir {} is not a directory".format(infer_dir)
+    assert os.path.isdir(infer_dir), "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


def download_file(url, filename):
    print(f"Downloading {filename}...")
    subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
    print("Download complete.")


if __name__ == '__main__':
    cur_path = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    FLAGS = parser.parse_args()

    if not os.path.exists(FLAGS.infer_cfg):
        infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
        download_file(infer_cfg_url, FLAGS.infer_cfg)

    if not os.path.exists(FLAGS.onnx_file):
        onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
        download_file(onnx_file_url, FLAGS.onnx_file)

    # load image list
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)

    if FLAGS.use_gpu:
-        predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CUDAExecutionProvider'])
+        predictor = onnxruntime.InferenceSession(
+            FLAGS.onnx_file, providers=['CUDAExecutionProvider']
+        )
    else:
-        predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CPUExecutionProvider'])
+        predictor = onnxruntime.InferenceSession(
+            FLAGS.onnx_file, providers=['CPUExecutionProvider']
+        )
    # load infer config
    infer_config = PredictConfig(FLAGS.infer_cfg)

    predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)

    os.chdir(cur_path)
@@ -18,32 +18,20 @@ from models.det_model.inference import PredictConfig
if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '-img',
-        type=str,
-        required=True,
-        help='path to the input image'
-    )
+    parser.add_argument('-img', type=str, required=True, help='path to the input image')
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
-        help='Inference mode, select one of cpu, cuda, or mps'
+        help='Inference mode, select one of cpu, cuda, or mps',
    )
    parser.add_argument(
-        '--num-beam',
-        type=int,
-        default=1,
-        help='number of beam search for decoding'
+        '--num-beam', type=int, default=1, help='number of beam search for decoding'
    )
-    parser.add_argument(
-        '-mix',
-        action='store_true',
-        help='use mix mode'
-    )
+    parser.add_argument('-mix', action='store_true', help='use mix mode')

    args = parser.parse_args()

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
@@ -63,8 +51,8 @@ if __name__ == '__main__':

    use_gpu = args.inference_mode == 'cuda'
    SIZE_LIMIT = 20 * 1024 * 1024
    det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
    rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
    # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
    det_use_gpu = False
    rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
@@ -78,8 +66,16 @@ if __name__ == '__main__':
        detector = predict_det.TextDetector(paddleocr_args)
        paddleocr_args.use_gpu = rec_use_gpu
        recognizer = predict_rec.TextRecognizer(paddleocr_args)

        lang_ocr_models = [detector, recognizer]
        latex_rec_models = [latex_rec_model, tokenizer]
-        res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
+        res = mix_inference(
+            img_path,
+            infer_config,
+            latex_det_model,
+            lang_ocr_models,
+            latex_rec_models,
+            args.inference_mode,
+            args.num_beam,
+        )
        print(res)
BIN texteller/models/__pycache__/globals.cpython-310.pyc (new file)
@@ -9,7 +9,7 @@ class Point:
    def __init__(self, x: int, y: int):
        self.x = int(x)
        self.y = int(y)

    def __repr__(self) -> str:
        return f"Point(x={self.x}, y={self.y})"

@@ -28,30 +28,28 @@ class Bbox:
    @property
    def ul_point(self) -> Point:
        return self.p

    @property
    def ur_point(self) -> Point:
        return Point(self.p.x + self.w, self.p.y)

    @property
    def ll_point(self) -> Point:
        return Point(self.p.x, self.p.y + self.h)

    @property
    def lr_point(self) -> Point:
        return Point(self.p.x + self.w, self.p.y + self.h)


    def same_row(self, other) -> bool:
-        if (
-            (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y)
-            or (self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y)
+        if (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y) or (
+            self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y
        ):
            return True
        if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
            return False
        return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD

    def __lt__(self, other) -> bool:
        '''
        from top to bottom, from left to right
@@ -60,7 +58,7 @@ class Bbox:
            return self.p.y < other.p.y
        else:
            return self.p.x < other.p.x

    def __repr__(self) -> str:
        return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"

@@ -76,16 +74,16 @@ def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"
        top = bbox.p.y
        right = bbox.p.x + bbox.w
        bottom = bbox.p.y + bbox.h

        # Draw the rectangle on the image
        drawer.rectangle([left, top, right, bottom], outline="green", width=1)

        # Optionally, add text label if it exists
        if bbox.label:
            drawer.text((left, top), bbox.label, fill="blue")

        if bbox.content:
            drawer.text((left, bottom - 10), bbox.content[:10], fill="red")

    # Save the image with drawn rectangles
    img.save(log_dir / name)
BIN texteller/models/det_model/__pycache__/Bbox.cpython-310.pyc (new file)
BIN texteller/models/det_model/__pycache__/inference.cpython-310.pyc (new file)
@@ -12,10 +12,28 @@ from .Bbox import Bbox

# Global dictionary
SUPPORT_MODELS = {
-    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
-    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
-    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
-    'DETR'
+    'YOLO',
+    'PPYOLOE',
+    'RCNN',
+    'SSD',
+    'Face',
+    'FCOS',
+    'SOLOv2',
+    'TTFNet',
+    'S2ANet',
+    'JDE',
+    'FairMOT',
+    'DeepSORT',
+    'GFL',
+    'PicoDet',
+    'CenterNet',
+    'TOOD',
+    'RetinaNet',
+    'StrongBaseline',
+    'STGCN',
+    'YOLOX',
+    'HRNet',
+    'DETR',
}


@@ -42,12 +60,12 @@ class PredictConfig(object):
        self.fpn_stride = yml_conf.get("fpn_stride", None)

        color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
-        self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
+        self.colors = {
+            label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)
+        }

        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
-            print(
-                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
-            )
+            print('The RCNN export model is used for ONNX and it only supports batch_size = 1')
        self.print_config()

    def check_model(self, yml_conf):
@@ -58,8 +76,7 @@ class PredictConfig(object):
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
-        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
-            'arch'], SUPPORT_MODELS))
+        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
@@ -77,8 +94,15 @@ def draw_bbox(image, outputs, infer_config):
        label = infer_config.label_list[int(cls_id)]
        color = infer_config.colors[label]
        cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
-        cv2.putText(image, "{}: {:.2f}".format(label, score),
-                    (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+        cv2.putText(
+            image,
+            "{}: {:.2f}".format(label, score),
+            (int(xmin), int(ymin - 5)),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            color,
+            2,
+        )
    return image


@@ -104,7 +128,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):

        inputs = transforms(img_path)
        inputs_name = [var.name for var in predictor.get_inputs()]
-        inputs = {k: inputs[k][None, ] for k in inputs_name}
+        inputs = {k: inputs[k][None,] for k in inputs_name}

        # Start timing
        start_time = time.time()
@@ -119,7 +143,9 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
        else:
            total_time += inference_time
            num_images += 1
-            print(f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds")
+            print(
+                f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds"
+            )

        print("ONNXRuntime predict: ")
        if infer_config.arch in ["HRNet"]:
@@ -128,8 +154,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
            bboxes = np.array(outputs[0])
            for bbox in bboxes:
                if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
-                    print(f"{int(bbox[0])} {bbox[1]} "
-                          f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
+                    print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")

        # Save the subimages (crop from the original image)
        subimg_counter = 1
@@ -137,7 +162,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
            cls_id, score, xmin, ymin, xmax, ymax = output
            if score > infer_config.draw_threshold:
                label = infer_config.label_list[int(cls_id)]
-                subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)]
+                subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)]
                if len(subimg) == 0:
                    continue

@@ -151,8 +176,14 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
        for output in np.array(outputs[0]):
            cls_id, score, xmin, ymin, xmax, ymax = output
            if score > infer_config.draw_threshold:
-                cv2.rectangle(img_with_mask, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 255), -1)  # mask the region with white
+                cv2.rectangle(
+                    img_with_mask,
+                    (int(xmin), int(ymin)),
+                    (int(xmax), int(ymax)),
+                    (255, 255, 255),
+                    -1,
+                )  # mask the region with white

        img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)

        output_dir = imgsave_dir
@@ -178,7 +209,7 @@ def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
    transforms = Compose(infer_config.preprocess_infos)
    inputs = transforms(img_path)
    inputs_name = [var.name for var in predictor.get_inputs()]
-    inputs = {k: inputs[k][None, ] for k in inputs_name}
+    inputs = {k: inputs[k][None,] for k in inputs_name}

    outputs = predictor.run(output_names=None, input_feed=inputs)[0]
    res = []
@@ -15,10 +15,8 @@ def decode_image(img_path):
    im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    img_info = {
-        "im_shape": np.array(
-            im.shape[:2], dtype=np.float32),
-        "scale_factor": np.array(
-            [1., 1.], dtype=np.float32)
+        "im_shape": np.array(im.shape[:2], dtype=np.float32),
+        "scale_factor": np.array([1.0, 1.0], dtype=np.float32),
    }
    return im, img_info

@@ -51,16 +49,9 @@ class Resize(object):
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im_channel = im.shape[2]
        im_scale_y, im_scale_x = self.generate_scale(im)
-        im = cv2.resize(
-            im,
-            None,
-            None,
-            fx=im_scale_x,
-            fy=im_scale_y,
-            interpolation=self.interp)
+        im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
-        im_info['scale_factor'] = np.array(
-            [im_scale_y, im_scale_x]).astype('float32')
+        im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, im):
@@ -134,7 +125,9 @@ class Permute(object):
        channel_first (bool): whether convert HWC to CHW
    """

-    def __init__(self, ):
+    def __init__(
+        self,
+    ):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
@@ -151,7 +144,7 @@ class Permute(object):


class PadStride(object):
-    """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
+    """padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
    Args:
        stride (bool): model with FPN need image shape % stride == 0
    """
@@ -198,18 +191,16 @@ class LetterBoxResize(object):
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
-        new_shape = (round(shape[1] * ratio),
-                     round(shape[0] * ratio))  # [width, height]
+        new_shape = (round(shape[1] * ratio), round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

-        img = cv2.resize(
-            img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
+        img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
-            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
-            value=color)  # padded rectangular
+            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
+        )  # padded rectangular
        return img, ratio, padw, padh

    def __call__(self, im, im_info):
@@ -302,12 +293,7 @@ def _get_3rd_point(a, b):
    return third_pt


-def get_affine_transform(center,
-                         input_size,
-                         rot,
-                         output_size,
-                         shift=(0., 0.),
-                         inv=False):
+def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
@@ -337,8 +323,8 @@ def get_affine_transform(center,
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
-    src_dir = rotate_point([0., src_w * -0.5], rot_rad)
-    dst_dir = np.array([0., dst_w * -0.5])
+    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
+    dst_dir = np.array([0.0, dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
|


class WarpAffine(object):
-    """Warp affine the image
-    """
+    """Warp affine the image"""

-    def __init__(self,
-                 keep_res=False,
-                 pad=31,
-                 input_h=512,
-                 input_w=512,
-                 scale=0.4,
-                 shift=0.1):
+    def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1):
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
|
        else:
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
-        c = np.array([w / 2., h / 2.], dtype=np.float32)
+        c = np.array([w / 2.0, h / 2.0], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        img = cv2.resize(img, (w, h))
-        inp = cv2.warpAffine(
-            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
+        inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
        return inp, im_info


@@ -432,13 +410,17 @@ def get_warp_matrix(theta, size_input, size_dst, size_target):
    matrix[0, 0] = np.cos(theta) * scale_x
    matrix[0, 1] = -np.sin(theta) * scale_x
    matrix[0, 2] = scale_x * (
-        -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
-        np.sin(theta) + 0.5 * size_target[0])
+        -0.5 * size_input[0] * np.cos(theta)
+        + 0.5 * size_input[1] * np.sin(theta)
+        + 0.5 * size_target[0]
+    )
    matrix[1, 0] = np.sin(theta) * scale_y
    matrix[1, 1] = np.cos(theta) * scale_y
    matrix[1, 2] = scale_y * (
-        -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
-        np.cos(theta) + 0.5 * size_target[1])
+        -0.5 * size_input[0] * np.sin(theta)
+        - 0.5 * size_input[1] * np.cos(theta)
+        + 0.5 * size_target[1]
+    )
    return matrix


@@ -462,22 +444,26 @@ class TopDownEvalAffine(object):
    def __call__(self, image, im_info):
        rot = 0
        imshape = im_info['im_shape'][::-1]
-        center = im_info['center'] if 'center' in im_info else imshape / 2.
+        center = im_info['center'] if 'center' in im_info else imshape / 2.0
        scale = im_info['scale'] if 'scale' in im_info else imshape
        if self.use_udp:
            trans = get_warp_matrix(
-                rot, center * 2.0,
-                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
+                rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale
+            )
            image = cv2.warpAffine(
                image,
-                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
-                flags=cv2.INTER_LINEAR)
+                trans,
+                (int(self.trainsize[0]), int(self.trainsize[1])),
+                flags=cv2.INTER_LINEAR,
+            )
        else:
            trans = get_affine_transform(center, scale, rot, self.trainsize)
            image = cv2.warpAffine(
                image,
-                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
-                flags=cv2.INTER_LINEAR)
+                trans,
+                (int(self.trainsize[0]), int(self.trainsize[1])),
+                flags=cv2.INTER_LINEAR,
+            )

        return image, im_info

@@ -1,6 +1,6 @@
# Formula image(grayscale) mean and variance
IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445

# Vocabulary size for TexTeller
VOCAB_SIZE = 15000
@@ -20,4 +20,4 @@ MIN_RESIZE_RATIO = 0.75

# Minimum height and width for input image for TexTeller
MIN_HEIGHT = 12
MIN_WIDTH = 30
@@ -1,30 +1,24 @@
from pathlib import Path

-from ...globals import (
-    VOCAB_SIZE,
-    FIXED_IMG_SIZE,
-    IMG_CHANNELS,
-    MAX_TOKEN_SIZE
-)
+from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE

-from transformers import (
-    RobertaTokenizerFast,
-    VisionEncoderDecoderModel,
-    VisionEncoderDecoderConfig
-)
+from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig


class TexTeller(VisionEncoderDecoderModel):
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
-        config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json")
+        config = VisionEncoderDecoderConfig.from_pretrained(
+            Path(__file__).resolve().parent / "config.json"
+        )
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE

        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
        if model_path is None or model_path == 'default':
@@ -32,8 +26,12 @@ class TexTeller(VisionEncoderDecoderModel):
                return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
            else:
                from optimum.onnxruntime import ORTModelForVision2Seq

                use_gpu = True if onnx_provider == 'cuda' else False
-                return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
+                return ORTModelForVision2Seq.from_pretrained(
+                    cls.REPO_NAME,
+                    provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
+                )
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

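As a usage note, the signature shown above means the ONNX/ORT path is chosen per call; a minimal sketch based on how these entry points are invoked elsewhere in this diff (the cuda provider line assumes onnxruntime-gpu and optimum are installed):

from models.ocr_model.model.TexTeller import TexTeller

latex_rec_model = TexTeller.from_pretrained()        # default checkpoint from OleehyO/TexTeller
tokenizer = TexTeller.get_tokenizer()
onnx_model = TexTeller.from_pretrained(use_onnx=True, onnx_provider='cuda')  # ORT backend on GPU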
BIN texteller/models/ocr_model/train/augraphy_cache/image_0.png (new file, 14 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_1.png (new file, 11 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_10.png (new file, 4.6 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_11.png (new file, 8.5 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_12.png (new file, 11 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_13.png (new file, 3.7 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_14.png (new file, 12 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_15.png (new file, 7.7 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_16.png (new file, 43 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_17.png (new file, 28 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_18.png (new file, 5.6 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_19.png (new file, 13 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_2.png (new file, 6.4 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_20.png (new file, 6.4 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_21.png (new file, 16 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_22.png (new file, 33 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_23.png (new file, 5.3 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_24.png (new file, 10 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_25.png (new file, 30 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_26.png (new file, 9.8 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_27.png (new file, 8.7 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_28.png (new file, 15 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_29.png (new file, 7.8 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_3.png (new file, 4.1 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_4.png (new file, 17 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_5.png (new file, 11 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_6.png (new file, 26 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_7.png (new file, 28 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_8.png (new file, 6.0 KiB)
BIN texteller/models/ocr_model/train/augraphy_cache/image_9.png (new file, 18 KiB)
@@ -0,0 +1,35 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
@@ -5,18 +5,24 @@ from pathlib import Path

from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
-    GenerationConfig
+    GenerationConfig,
)

from .training_args import CONFIG
from ..model.TexTeller import TexTeller
-from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
+from ..utils.functional import (
+    tokenize_fn,
+    collate_fn,
+    img_train_transform,
+    img_inf_transform,
+    filter_fn,
+)
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -24,11 +30,9 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
    trainer = Trainer(
        model,
        training_args,
-
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
-
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

@@ -52,43 +56,44 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
    trainer = Seq2SeqTrainer(
        model,
        seq2seq_config,
-
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
-        compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
+        compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
    )

    eval_res = trainer.evaluate()
    print(eval_res)


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

-    dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
-    dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
+    # dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
+    dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
+    dataset = dataset.filter(
+        lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
+    )
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = TexTeller.get_tokenizer()
    # If you want use your own tokenizer, please modify the path to your tokenizer
-    #+tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
+    # +tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
-    dataset = dataset.filter(
-        filter_fn_with_tokenizer,
-        num_proc=8
-    )
+    dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
-    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
+    tokenized_dataset = dataset.map(
+        map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
+    )

    # Split dataset into train and eval, ratio 9:1
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
    train_dataset = train_dataset.with_transform(img_train_transform)
    eval_dataset = eval_dataset.with_transform(img_inf_transform)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch
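
Note on the hunk above (not part of the diff): switching from the custom loader.py script to the Hugging Face "imagefolder" builder relies on the JSONL metadata added earlier in this diff, which imagefolder expects to find as metadata.jsonl next to the images. Each image is paired with the JSON line whose file_name matches it, and the remaining keys (here latex_formula) become dataset columns. A minimal sketch of that layout and load call, with an illustrative data_dir:

# Minimal sketch; the directory name below is illustrative, not taken from the diff.
#
#   dataset/
#       metadata.jsonl   # one JSON object per line: {"file_name": "0.png", "latex_formula": "..."}
#       0.png
#       1.png
#       ...
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="dataset")["train"]
print(dataset.column_names)         # ['image', 'latex_formula']
print(dataset[0]["latex_formula"])  # the LaTeX string paired with 0.png
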
@@ -96,14 +101,14 @@ if __name__ == '__main__':
    # or train from TexTeller pre-trained model: model = TexTeller.from_pretrained()

    # If you want to train from pre-trained model, please modify the path to your pre-trained checkpoint
-    #+e.g.
-    #+model = TexTeller.from_pretrained(
-    #+ '/path/to/your/model_checkpoint'
-    #+)
+    # +e.g.
+    # +model = TexTeller.from_pretrained(
+    # + '/path/to/your/model_checkpoint'
+    # +)

    enable_train = True
    enable_evaluate = False
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate and len(eval_dataset) > 0:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
texteller/models/ocr_model/train/training_args.py (new file)
@@ -0,0 +1,31 @@
CONFIG = {
    "seed": 42,  # Random seed for reproducibility
    "use_cpu": False,  # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
    "learning_rate": 5e-5,  # Learning rate
    "num_train_epochs": 10,  # Total number of training epochs
    "per_device_train_batch_size": 4,  # Batch size per GPU for training
    "per_device_eval_batch_size": 8,  # Batch size per GPU for evaluation
    "output_dir": "train_result",  # Output directory
    "overwrite_output_dir": False,  # If the output directory exists, do not delete its content
    "report_to": ["tensorboard"],  # Report logs to TensorBoard
    "save_strategy": "steps",  # Strategy to save checkpoints
    "save_steps": 500,  # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
    "save_total_limit": 5,  # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
    "logging_strategy": "steps",  # Log every certain number of steps
    "logging_steps": 500,  # Number of steps between each log
    "logging_nan_inf_filter": False,  # Record logs for loss=nan or inf
    "optim": "adamw_torch",  # Optimizer
    "lr_scheduler_type": "cosine",  # Learning rate scheduler
    "warmup_ratio": 0.1,  # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
    "max_grad_norm": 1.0,  # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
    "fp16": False,  # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
    "bf16": False,  # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
    "gradient_accumulation_steps": 1,  # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
    "jit_mode_eval": False,  # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
    "torch_compile": False,  # Whether to use torch.compile to compile the model (for better training and inference performance)
    "dataloader_pin_memory": True,  # Can speed up data transfer between CPU and GPU
    "dataloader_num_workers": 1,  # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
    "evaluation_strategy": "steps",  # Evaluation strategy, can be "steps" or "epoch"
    "eval_steps": 500,  # If evaluation_strategy="step"
    "remove_unused_columns": False,  # Don't change this unless you really know what you are doing.
}
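
Note on the new file above (not part of the diff): the keys in CONFIG map onto parameters of transformers.TrainingArguments. A minimal sketch of how the training script presumably consumes it; the actual construction of training_args is not visible in the hunks shown here and is an assumption:

# Hedged sketch: assumes training_args is built as TrainingArguments(**CONFIG).
from transformers import TrainingArguments

from .training_args import CONFIG  # mirrors the relative import used by the training script

training_args = TrainingArguments(**CONFIG)
# e.g. training_args.learning_rate == 5e-5, training_args.eval_steps == 500
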
@@ -26,7 +26,7 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
    pixel_values = [dic.pop('pixel_values') for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch['pixel_values'] = pixel_values
    batch['decoder_input_ids'] = batch.pop('input_ids')
@@ -54,6 +54,7 @@ def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:

def filter_fn(sample, tokenizer=None) -> bool:
    return (
-        sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
+        sample['image'].height > MIN_HEIGHT
+        and sample['image'].width > MIN_WIDTH
        and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
    )
@@ -12,7 +12,7 @@ def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
            continue
        if image.dtype == np.uint16:
            print(f'Converting {path} to 8-bit, image may be lossy.')
-            image = cv2.convertScaleAbs(image, alpha=(255.0/65535.0))
+            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))

        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
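
Note on the hunk above (not part of the diff): the reformatted line only adds spaces around the division; the underlying convertScaleAbs call multiplies each 16-bit value by 255/65535 and saturates to uint8, so the full 16-bit range collapses onto 0-255. A small self-contained check with illustrative values:

# Illustrative check of the uint16 -> uint8 scaling used in convert2rgb:
# 0 stays 0, 65535 maps to 255, and mid-range values scale proportionally.
import cv2
import numpy as np

img16 = np.array([[0, 32768, 65535]], dtype=np.uint16)
img8 = cv2.convertScaleAbs(img16, alpha=255.0 / 65535.0)
print(img8.dtype, img8.tolist())  # uint8 [[0, 128, 255]]
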