49 lines
1.9 KiB
Python
49 lines
1.9 KiB
Python
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from transformers import RobertaTokenizerFast, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
|
||
|
|
|
||
|
|
from texteller.constants import (
|
||
|
|
FIXED_IMG_SIZE,
|
||
|
|
IMG_CHANNELS,
|
||
|
|
MAX_TOKEN_SIZE,
|
||
|
|
VOCAB_SIZE,
|
||
|
|
)
|
||
|
|
from texteller.globals import Globals
|
||
|
|
from texteller.types import TexTellerModel
|
||
|
|
from texteller.utils import cuda_available
|
||
|
|
|
||
|
|
|
||
|
|
class TexTeller(VisionEncoderDecoderModel):
    """VisionEncoderDecoder model configured for the TexTeller project.

    Instantiating the class builds a fresh model whose config is loaded from
    ``Globals().repo_name`` and then overridden with this project's image and
    vocabulary constants. Use :meth:`from_pretrained` to load trained weights
    and :meth:`get_tokenizer` to load the matching tokenizer.
    """

    def __init__(self):
        config = VisionEncoderDecoderConfig.from_pretrained(Globals().repo_name)
        # Override the loaded config so the architecture matches this
        # project's constants: encoder input image size / channel count and
        # decoder vocabulary size / maximum position embeddings.
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE

        super().__init__(config=config)

    @staticmethod
    def _resolve_source(path: str | Path | None) -> str:
        """Return the identifier to hand to a Hugging Face ``from_pretrained``.

        ``None`` or a value equal to ``Globals().repo_name`` maps to that repo
        name. Anything else is treated as a local path and resolved to an
        absolute path string.
        """
        repo_name = Globals().repo_name
        if path is None or path == repo_name:
            return repo_name
        return str(Path(path).resolve())

    @classmethod
    def from_pretrained(
        cls, model_dir: str | Path | None = None, use_onnx: bool = False
    ) -> TexTellerModel:
        """Load trained TexTeller weights.

        Args:
            model_dir: Local directory containing saved weights, or ``None``
                (or the default repo name) to load from ``Globals().repo_name``.
            use_onnx: If true, load an ONNX Runtime model through ``optimum``
                instead of the PyTorch model. The CUDA execution provider is
                used when ``cuda_available()`` is true, otherwise the CPU
                provider. The flag applies to both the default repo and
                local directories. NOTE(review): a local directory presumably
                must then contain an ONNX export; confirm with optimum docs.

        Returns:
            A ``VisionEncoderDecoderModel`` (PyTorch), or an
            ``ORTModelForVision2Seq`` when ``use_onnx`` is true. Note that the
            PyTorch path returns the base class, not a ``TexTeller`` instance.
        """
        source = cls._resolve_source(model_dir)
        if not use_onnx:
            return VisionEncoderDecoderModel.from_pretrained(source)

        # Imported lazily so optimum / onnxruntime are only required when
        # ONNX inference is actually requested.
        from optimum.onnxruntime import ORTModelForVision2Seq

        provider = "CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"
        return ORTModelForVision2Seq.from_pretrained(source, provider=provider)

    @classmethod
    def get_tokenizer(cls, tokenizer_dir: str | Path | None = None) -> RobertaTokenizerFast:
        """Load the TexTeller tokenizer.

        Args:
            tokenizer_dir: Local tokenizer directory, or ``None`` (or the
                default repo name) to load from ``Globals().repo_name``.

        Returns:
            The ``RobertaTokenizerFast`` loaded from the resolved location.
        """
        return RobertaTokenizerFast.from_pretrained(cls._resolve_source(tokenizer_dir))
|