2024-02-11 08:06:50 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
2025-02-28 19:56:49 +08:00
|
|
|
from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE
|
2024-02-11 08:06:50 +00:00
|
|
|
|
2025-02-28 19:56:49 +08:00
|
|
|
from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
|
2024-02-11 08:06:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TexTeller(VisionEncoderDecoderModel):
    """Vision encoder-decoder model that maps formula images to LaTeX token sequences.

    Wraps ``transformers.VisionEncoderDecoderModel`` with a project-specific
    configuration (image size, channel count, vocab size, max sequence length)
    and convenience loaders for the pretrained checkpoint and tokenizer.
    """

    # Hugging Face Hub repository that hosts the pretrained weights/tokenizer.
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
        """Build an untrained model from the bundled ``config.json``.

        The architecture config shipped next to this file is loaded and then
        overridden with the project-wide constants so the model always agrees
        with the preprocessing pipeline.
        """
        config = VisionEncoderDecoderConfig.from_pretrained(
            Path(__file__).resolve().parent / "config.json"
        )
        # Force the encoder/decoder dimensions to match the global constants
        # used by the rest of the pipeline.
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_path: "str | None" = None, use_onnx=False, onnx_provider=None):
        """Load pretrained weights.

        Args:
            model_path: Local checkpoint directory. ``None`` or ``'default'``
                loads the published checkpoint from the Hub (``REPO_NAME``).
            use_onnx: When loading the default checkpoint, load the ONNX
                export via ``optimum`` instead of the PyTorch weights.
            onnx_provider: ``'cuda'`` selects ``CUDAExecutionProvider``;
                any other value falls back to ``CPUExecutionProvider``.

        Returns:
            A ``VisionEncoderDecoderModel`` (or ``ORTModelForVision2Seq``
            when ``use_onnx`` is set and the default checkpoint is used).
        """
        if model_path is None or model_path == 'default':
            if not use_onnx:
                return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
            # Imported lazily: optimum/onnxruntime is an optional dependency
            # only needed on the ONNX path.
            from optimum.onnxruntime import ORTModelForVision2Seq

            return ORTModelForVision2Seq.from_pretrained(
                cls.REPO_NAME,
                provider="CUDAExecutionProvider" if onnx_provider == 'cuda' else "CPUExecutionProvider",
            )
        # NOTE(review): use_onnx/onnx_provider are ignored when a local
        # model_path is supplied — confirm this is intentional.
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

    @classmethod
    def get_tokenizer(cls, tokenizer_path: "str | None" = None) -> RobertaTokenizerFast:
        """Load the tokenizer.

        Args:
            tokenizer_path: Local tokenizer directory. ``None`` or
                ``'default'`` loads the tokenizer from the Hub (``REPO_NAME``).

        Returns:
            The ``RobertaTokenizerFast`` paired with this model.
        """
        if tokenizer_path is None or tokenizer_path == 'default':
            return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
        tokenizer_path = Path(tokenizer_path).resolve()
        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
|