Files
TexTeller/src/models/ocr_model/model/TexTeller.py

41 lines
1.4 KiB
Python
Raw Normal View History

2024-02-11 08:06:50 +00:00
from pathlib import Path
2024-04-16 13:56:56 +00:00
from ...globals import (
2024-02-11 08:06:50 +00:00
VOCAB_SIZE,
FIXED_IMG_SIZE,
IMG_CHANNELS,
MAX_TOKEN_SIZE
2024-02-11 08:06:50 +00:00
)
from transformers import (
RobertaTokenizerFast,
VisionEncoderDecoderModel,
VisionEncoderDecoderConfig,
2024-02-11 08:06:50 +00:00
)
class TexTeller(VisionEncoderDecoderModel):
    """TexTeller model built on Hugging Face ``VisionEncoderDecoderModel``.

    The architecture comes from the ``config.json`` stored next to this
    module. The encoder image size / channel count and the decoder
    vocabulary size / maximum position embeddings are overridden with the
    project constants imported from ``globals``.
    """

    # Hugging Face Hub repository used when no explicit location is given
    # (i.e. ``None`` or the literal string ``'default'``).
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
        """Build a freshly initialised TexTeller from the bundled config.

        Reads ``config.json`` from this module's directory, applies the
        ``FIXED_IMG_SIZE``, ``IMG_CHANNELS``, ``VOCAB_SIZE`` and
        ``MAX_TOKEN_SIZE`` overrides, and initialises the parent model with
        it. No pretrained weights are loaded here; use :meth:`from_pretrained`
        for that.
        """
        config = VisionEncoderDecoderConfig.from_pretrained(
            Path(__file__).resolve().parent / "config.json"
        )
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
        super().__init__(config=config)

    @classmethod
    def _resolve_source(cls, path) -> str:
        """Map a user-supplied location to a ``from_pretrained`` argument.

        ``None`` or ``'default'`` selects :attr:`REPO_NAME`. A path that
        exists on disk (after ``~`` expansion) is returned as an absolute
        path string. Anything else is returned unchanged as a string so it
        can be used as a Hugging Face Hub repository id.
        """
        if path is None or path == 'default':
            return cls.REPO_NAME
        local = Path(path).expanduser()
        if local.exists():
            return str(local.resolve())
        # Not on disk: resolving would turn e.g. "user/repo" into
        # "/cwd/user/repo", which is neither a local directory nor a valid
        # Hub repository id, so pass the original value through instead.
        return str(path)

    @classmethod
    def from_pretrained(
        cls, model_path: "str | Path | None" = None, **kwargs
    ) -> VisionEncoderDecoderModel:
        """Load pretrained model weights.

        Args:
            model_path: ``None`` or ``'default'`` to load :attr:`REPO_NAME`
                from the Hub; otherwise a local checkpoint directory or a
                Hub repository id.
            **kwargs: Forwarded unchanged to
                ``VisionEncoderDecoderModel.from_pretrained``.

        Returns:
            The loaded model. Note that it is a ``VisionEncoderDecoderModel``
            instance, not a ``TexTeller`` instance.
        """
        # Loading goes through the base class rather than ``cls``/``super()``.
        # NOTE(review): presumably because ``TexTeller.__init__`` takes no
        # ``config`` argument, which the HF loader passes when it constructs
        # the model class — confirm before changing.
        return VisionEncoderDecoderModel.from_pretrained(
            cls._resolve_source(model_path), **kwargs
        )

    @classmethod
    def get_tokenizer(
        cls, tokenizer_path: "str | Path | None" = None, **kwargs
    ) -> RobertaTokenizerFast:
        """Load the tokenizer that pairs with the model.

        Args:
            tokenizer_path: ``None`` or ``'default'`` to load
                :attr:`REPO_NAME` from the Hub; otherwise a local tokenizer
                directory or a Hub repository id.
            **kwargs: Forwarded unchanged to
                ``RobertaTokenizerFast.from_pretrained``.

        Returns:
            The loaded ``RobertaTokenizerFast``.
        """
        return RobertaTokenizerFast.from_pretrained(
            cls._resolve_source(tokenizer_path), **kwargs
        )