Merge pull request #78 from OleehyO/pre_release

Change to better import dependency
This commit is contained in:
OleehyO
2024-08-07 12:43:15 +08:00
committed by GitHub

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from optimum.onnxruntime import ORTModelForVision2Seq
from ...globals import (
VOCAB_SIZE,
@@ -32,6 +31,7 @@ class TexTeller(VisionEncoderDecoderModel):
if not use_onnx:
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
else:
from optimum.onnxruntime import ORTModelForVision2Seq
use_gpu = True if onnx_provider == 'cuda' else False
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
model_path = Path(model_path).resolve()