Change the model configuration to trocr

This commit is contained in:
三洋三洋
2024-05-28 04:20:07 +00:00
parent 9b11689f22
commit 89aa396cbb
2 changed files with 162 additions and 6 deletions

View File

@@ -10,18 +10,18 @@ from ...globals import (
from transformers import (
RobertaTokenizerFast,
VisionEncoderDecoderModel,
VisionEncoderDecoderConfig
VisionEncoderDecoderConfig,
)
class TexTeller(VisionEncoderDecoderModel):
REPO_NAME = 'OleehyO/TexTeller'
def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/trocr-small')
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size=VOCAB_SIZE
config.decoder.max_position_embeddings=MAX_TOKEN_SIZE
config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json")
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size = VOCAB_SIZE
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
super().__init__(config=config)