2024-02-11 08:06:50 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
from models.globals import (
|
|
|
|
|
VOCAB_SIZE,
|
|
|
|
|
FIXED_IMG_SIZE,
|
|
|
|
|
IMG_CHANNELS,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from transformers import (
|
|
|
|
|
ViTConfig,
|
|
|
|
|
ViTModel,
|
|
|
|
|
TrOCRConfig,
|
|
|
|
|
TrOCRForCausalLM,
|
|
|
|
|
RobertaTokenizerFast,
|
|
|
|
|
VisionEncoderDecoderModel,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TexTeller(VisionEncoderDecoderModel):
    """Vision encoder-decoder model pairing a ViT encoder with a TrOCR decoder.

    Instantiating the class directly builds both halves from configuration
    objects only; no pretrained weights are loaded in ``__init__``. Use
    :meth:`from_pretrained` and :meth:`get_tokenizer` to load a saved
    checkpoint and its tokenizer.
    """

    # NOTE(review): this is an absolute path on one developer's machine, so
    # the 'default' checkpoint can only be loaded where that directory exists.
    # Confirm the intended published location before relying on the default.
    REPO_NAME = '/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-588000'

    def __init__(self, decoder_path=None, tokenizer_path=None):
        """Build the encoder and decoder from their configs.

        Args:
            decoder_path: Not read by this constructor; kept so existing
                callers that pass it keep working.
            tokenizer_path: Not read by this constructor; kept so existing
                callers that pass it keep working.
        """
        encoder = ViTModel(ViTConfig(
            image_size=FIXED_IMG_SIZE,
            num_channels=IMG_CHANNELS
        ))
        decoder = TrOCRForCausalLM(TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        ))
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def _resolve_source(cls, path):
        """Turn a user-supplied location into the string handed to ``from_pretrained``.

        ``None`` and ``'default'`` select ``cls.REPO_NAME``. A location that
        exists on disk is converted to an absolute path. Anything else is
        returned unchanged so that non-filesystem identifiers (for example a
        Hugging Face Hub repo id) are not rewritten into a bogus path under
        the current working directory.
        """
        if path is None or path == 'default':
            return cls.REPO_NAME

        local = Path(path)
        if local.exists():
            return str(local.resolve())
        return str(path)

    @classmethod
    def from_pretrained(cls, model_path: str = None) -> VisionEncoderDecoderModel:
        """Load a saved model.

        Args:
            model_path: Checkpoint directory or model identifier. ``None`` or
                ``'default'`` loads from ``cls.REPO_NAME``.

        Returns:
            The loaded model. Loading is delegated to
            ``VisionEncoderDecoderModel.from_pretrained`` directly, so the
            result is created by that class rather than through
            ``TexTeller.__init__``.
        """
        return VisionEncoderDecoderModel.from_pretrained(cls._resolve_source(model_path))

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        """Load the tokenizer that accompanies a checkpoint.

        Args:
            tokenizer_path: Tokenizer directory or identifier. ``None`` or
                ``'default'`` loads from ``cls.REPO_NAME``.

        Returns:
            The loaded ``RobertaTokenizerFast``.
        """
        return RobertaTokenizerFast.from_pretrained(cls._resolve_source(tokenizer_path))
|