Change to better import dependency
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
|
||||||
|
|
||||||
from ...globals import (
|
from ...globals import (
|
||||||
VOCAB_SIZE,
|
VOCAB_SIZE,
|
||||||
@@ -32,6 +31,7 @@ class TexTeller(VisionEncoderDecoderModel):
|
|||||||
if not use_onnx:
|
if not use_onnx:
|
||||||
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
|
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
|
||||||
else:
|
else:
|
||||||
|
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||||
use_gpu = True if onnx_provider == 'cuda' else False
|
use_gpu = True if onnx_provider == 'cuda' else False
|
||||||
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
|
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
|
||||||
model_path = Path(model_path).resolve()
|
model_path = Path(model_path).resolve()
|
||||||
|
|||||||
Reference in New Issue
Block a user