67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
from pathlib import Path
|
|
|
|
import wget
|
|
from onnxruntime import InferenceSession
|
|
from transformers import RobertaTokenizerFast
|
|
|
|
from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL
|
|
from texteller.globals import Globals
|
|
from texteller.logger import get_logger
|
|
from texteller.models import TexTeller
|
|
from texteller.paddleocr import predict_det, predict_rec
|
|
from texteller.paddleocr.utility import parse_args
|
|
from texteller.utils import cuda_available, mkdir, resolve_path
|
|
from texteller.types import TexTellerModel
|
|
|
|
_logger = get_logger(__name__)
|
|
|
|
|
|
def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
|
|
return TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
|
|
|
|
|
|
def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
|
|
return TexTeller.get_tokenizer(tokenizer_dir)
|
|
|
|
|
|
def load_latexdet_model() -> InferenceSession:
|
|
fpath = _maybe_download(LATEX_DET_MODEL_URL)
|
|
return InferenceSession(
|
|
resolve_path(fpath),
|
|
providers=["CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"],
|
|
)
|
|
|
|
|
|
def load_textrec_model() -> predict_rec.TextRecognizer:
|
|
fpath = _maybe_download(TEXT_REC_MODEL_URL)
|
|
paddleocr_args = parse_args()
|
|
paddleocr_args.use_onnx = True
|
|
paddleocr_args.rec_model_dir = resolve_path(fpath)
|
|
paddleocr_args.use_gpu = cuda_available()
|
|
predictor = predict_rec.TextRecognizer(paddleocr_args)
|
|
return predictor
|
|
|
|
|
|
def load_textdet_model() -> predict_det.TextDetector:
|
|
fpath = _maybe_download(TEXT_DET_MODEL_URL)
|
|
paddleocr_args = parse_args()
|
|
paddleocr_args.use_onnx = True
|
|
paddleocr_args.det_model_dir = resolve_path(fpath)
|
|
paddleocr_args.use_gpu = cuda_available()
|
|
predictor = predict_det.TextDetector(paddleocr_args)
|
|
return predictor
|
|
|
|
|
|
def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
|
|
if dirpath is None:
|
|
dirpath = Globals().cache_dir
|
|
mkdir(dirpath)
|
|
|
|
fname = Path(url).name
|
|
fpath = Path(dirpath) / fname
|
|
if not fpath.exists() or force:
|
|
_logger.info(f"Downloading {fname} from {url} to {fpath}")
|
|
wget.download(url, resolve_path(fpath))
|
|
|
|
return fpath
|