Eliminated dependency on paddleocr

Change to trocr
This commit is contained in:
三洋三洋
2024-05-27 16:45:33 +00:00
parent edef073812
commit 714fef4def
2 changed files with 13 additions and 15 deletions

View File

@@ -4,29 +4,26 @@ from ...globals import (
VOCAB_SIZE,
FIXED_IMG_SIZE,
IMG_CHANNELS,
MAX_TOKEN_SIZE
)
from transformers import (
ViTConfig,
ViTModel,
TrOCRConfig,
TrOCRForCausalLM,
RobertaTokenizerFast,
VisionEncoderDecoderModel,
VisionEncoderDecoderConfig
)
class TexTeller(VisionEncoderDecoderModel):
REPO_NAME = 'OleehyO/TexTeller'
def __init__(self, decoder_path=None, tokenizer_path=None):
encoder = ViTModel(ViTConfig(
image_size=FIXED_IMG_SIZE,
num_channels=IMG_CHANNELS
))
decoder = TrOCRForCausalLM(TrOCRConfig(
vocab_size=VOCAB_SIZE,
))
super().__init__(encoder=encoder, decoder=decoder)
def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/trocr-small')
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size=VOCAB_SIZE
config.decoder.max_position_embeddings=MAX_TOKEN_SIZE
super().__init__(config=config)
@classmethod
def from_pretrained(cls, model_path: str = None):