Support onnx runtime
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
|
||||
from ...globals import (
|
||||
VOCAB_SIZE,
|
||||
@@ -10,25 +11,29 @@ from ...globals import (
|
||||
from transformers import (
|
||||
RobertaTokenizerFast,
|
||||
VisionEncoderDecoderModel,
|
||||
VisionEncoderDecoderConfig,
|
||||
VisionEncoderDecoderConfig
|
||||
)
|
||||
|
||||
|
||||
class TexTeller(VisionEncoderDecoderModel):
    """Vision-encoder/decoder OCR model for LaTeX formula recognition.

    Builds its config from the ``config.json`` shipped next to this file and
    overrides the image/vocabulary dimensions with the project-wide globals.
    """

    # Hugging Face Hub repository holding the published weights.
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
        # Load the architecture config bundled with this module (portable,
        # unlike a hardcoded absolute path) and pin the I/O dimensions to the
        # project globals so tokenizer/model stay in sync.
        config = VisionEncoderDecoderConfig.from_pretrained(
            Path(__file__).resolve().parent / "config.json"
        )
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
        """Load TexTeller weights.

        Args:
            model_path: Local checkpoint directory. ``None`` or ``'default'``
                downloads the published weights from the Hub.
            use_onnx: When loading the default weights, serve them through
                ONNX Runtime instead of plain PyTorch.
            onnx_provider: ``'cuda'`` selects CUDAExecutionProvider; anything
                else falls back to CPUExecutionProvider.

        Returns:
            A ``VisionEncoderDecoderModel`` (PyTorch) or an
            ``ORTModelForVision2Seq`` (ONNX Runtime) instance.
        """
        if model_path is None or model_path == 'default':
            if not use_onnx:
                return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
            # ONNX Runtime path: pick the execution provider from the flag.
            provider = (
                "CUDAExecutionProvider" if onnx_provider == 'cuda'
                else "CPUExecutionProvider"
            )
            return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider=provider)
        # Explicit local checkpoint: resolve to an absolute path for HF.
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user