from pathlib import Path
from transformers import RobertaTokenizerFast, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
from texteller.constants import (
FIXED_IMG_SIZE,
IMG_CHANNELS,
MAX_TOKEN_SIZE,
VOCAB_SIZE,
)
from texteller.globals import Globals
from texteller.types import TexTellerModel
from texteller.utils import cuda_available
class TexTeller(VisionEncoderDecoderModel):
    """Vision encoder-decoder model that converts formula images to LaTeX.

    Wraps ``VisionEncoderDecoderModel`` with the project-specific image and
    vocabulary configuration from ``texteller.constants``, and provides
    loaders for both the published hub checkpoint and local directories.
    """

    def __init__(self):
        # Start from the published checkpoint's config, then pin the
        # input/output dimensions this project is trained with.
        config = VisionEncoderDecoderConfig.from_pretrained(Globals().repo_name)
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
        """Load model weights.

        Args:
            model_dir: Local checkpoint directory. ``None`` (or the hub repo
                name) loads the published checkpoint from the hub.
            use_onnx: When loading from the hub, load the ONNX export via
                ``optimum`` instead of the PyTorch weights.
                NOTE(review): this flag is ignored when ``model_dir`` is a
                local path — that branch always loads PyTorch weights;
                confirm this is intended.

        Returns:
            The loaded model: a ``VisionEncoderDecoderModel``, or an
            ``ORTModelForVision2Seq`` when ``use_onnx`` is set.
        """
        if model_dir is None or model_dir == Globals().repo_name:
            if not use_onnx:
                return VisionEncoderDecoderModel.from_pretrained(Globals().repo_name)
            # Imported lazily: optimum/onnxruntime is an optional dependency.
            from optimum.onnxruntime import ORTModelForVision2Seq

            return ORTModelForVision2Seq.from_pretrained(
                Globals().repo_name,
                provider="CUDAExecutionProvider"
                if cuda_available()
                else "CPUExecutionProvider",
            )
        return VisionEncoderDecoderModel.from_pretrained(str(Path(model_dir).resolve()))

    @classmethod
    def get_tokenizer(cls, tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
        """Load the tokenizer.

        Args:
            tokenizer_dir: Local tokenizer directory. ``None`` (or the hub
                repo name) loads the tokenizer of the published checkpoint.

        Returns:
            The ``RobertaTokenizerFast`` for this model.
        """
        if tokenizer_dir is None or tokenizer_dir == Globals().repo_name:
            return RobertaTokenizerFast.from_pretrained(Globals().repo_name)
        return RobertaTokenizerFast.from_pretrained(str(Path(tokenizer_dir).resolve()))