Change to better import dependency

This commit is contained in:
三洋三洋
2024-08-07 01:19:26 +08:00
parent bbc8ecf88b
commit e1046ba3fa

View File

@@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from optimum.onnxruntime import ORTModelForVision2Seq
from ...globals import ( from ...globals import (
VOCAB_SIZE, VOCAB_SIZE,
@@ -32,6 +31,7 @@ class TexTeller(VisionEncoderDecoderModel):
if not use_onnx: if not use_onnx:
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME) return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
else: else:
from optimum.onnxruntime import ORTModelForVision2Seq
use_gpu = True if onnx_provider == 'cuda' else False use_gpu = True if onnx_provider == 'cuda' else False
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider") return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
model_path = Path(model_path).resolve() model_path = Path(model_path).resolve()