[refactor] Init
This commit is contained in:
66
texteller/api/load.py
Normal file
66
texteller/api/load.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from pathlib import Path
|
||||
|
||||
import wget
|
||||
from onnxruntime import InferenceSession
|
||||
from transformers import RobertaTokenizerFast
|
||||
|
||||
from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL
|
||||
from texteller.globals import Globals
|
||||
from texteller.logger import get_logger
|
||||
from texteller.models import TexTeller
|
||||
from texteller.paddleocr import predict_det, predict_rec
|
||||
from texteller.paddleocr.utility import parse_args
|
||||
from texteller.utils import cuda_available, mkdir, resolve_path
|
||||
from texteller.types import TexTellerModel
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
|
||||
return TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
|
||||
|
||||
|
||||
def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
|
||||
return TexTeller.get_tokenizer(tokenizer_dir)
|
||||
|
||||
|
||||
def load_latexdet_model() -> InferenceSession:
|
||||
fpath = _maybe_download(LATEX_DET_MODEL_URL)
|
||||
return InferenceSession(
|
||||
resolve_path(fpath),
|
||||
providers=["CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
|
||||
def load_textrec_model() -> predict_rec.TextRecognizer:
|
||||
fpath = _maybe_download(TEXT_REC_MODEL_URL)
|
||||
paddleocr_args = parse_args()
|
||||
paddleocr_args.use_onnx = True
|
||||
paddleocr_args.rec_model_dir = resolve_path(fpath)
|
||||
paddleocr_args.use_gpu = cuda_available()
|
||||
predictor = predict_rec.TextRecognizer(paddleocr_args)
|
||||
return predictor
|
||||
|
||||
|
||||
def load_textdet_model() -> predict_det.TextDetector:
|
||||
fpath = _maybe_download(TEXT_DET_MODEL_URL)
|
||||
paddleocr_args = parse_args()
|
||||
paddleocr_args.use_onnx = True
|
||||
paddleocr_args.det_model_dir = resolve_path(fpath)
|
||||
paddleocr_args.use_gpu = cuda_available()
|
||||
predictor = predict_det.TextDetector(paddleocr_args)
|
||||
return predictor
|
||||
|
||||
|
||||
def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
|
||||
if dirpath is None:
|
||||
dirpath = Globals().cache_dir
|
||||
mkdir(dirpath)
|
||||
|
||||
fname = Path(url).name
|
||||
fpath = Path(dirpath) / fname
|
||||
if not fpath.exists() or force:
|
||||
_logger.info(f"Downloading {fname} from {url} to {fpath}")
|
||||
wget.download(url, resolve_path(fpath))
|
||||
|
||||
return fpath
|
||||
Reference in New Issue
Block a user