Change to better import dependency
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
|
||||
from ...globals import (
|
||||
VOCAB_SIZE,
|
||||
@@ -32,6 +31,7 @@ class TexTeller(VisionEncoderDecoderModel):
|
||||
if not use_onnx:
|
||||
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
|
||||
else:
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
use_gpu = True if onnx_provider == 'cuda' else False
|
||||
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
|
||||
model_path = Path(model_path).resolve()
|
||||
|
||||
Reference in New Issue
Block a user