From 5c9cff21258d55f5277436d5a705d22cb8bd7e2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Mon, 27 May 2024 16:45:33 +0000 Subject: [PATCH] Eliminated dependency on paddleocr Change to trocr --- requirements.txt | 5 +++-- src/models/ocr_model/model/TexTeller.py | 23 ++++++++++------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/requirements.txt b/requirements.txt index c4861aa..d8811fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,9 +9,10 @@ nltk python-multipart augraphy -onnxruntime-gpu streamlit==1.30 streamlit-paste-button -paddleocr +shapely +pyclipper +onnxruntime-gpu diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py index 93bd03d..08ca257 100644 --- a/src/models/ocr_model/model/TexTeller.py +++ b/src/models/ocr_model/model/TexTeller.py @@ -4,29 +4,26 @@ from ...globals import ( VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, + MAX_TOKEN_SIZE ) from transformers import ( - ViTConfig, - ViTModel, - TrOCRConfig, - TrOCRForCausalLM, RobertaTokenizerFast, VisionEncoderDecoderModel, + VisionEncoderDecoderConfig ) class TexTeller(VisionEncoderDecoderModel): REPO_NAME = 'OleehyO/TexTeller' - def __init__(self, decoder_path=None, tokenizer_path=None): - encoder = ViTModel(ViTConfig( - image_size=FIXED_IMG_SIZE, - num_channels=IMG_CHANNELS - )) - decoder = TrOCRForCausalLM(TrOCRConfig( - vocab_size=VOCAB_SIZE, - )) - super().__init__(encoder=encoder, decoder=decoder) + def __init__(self): + config = VisionEncoderDecoderConfig.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/trocr-small') + config.encoder.image_size = FIXED_IMG_SIZE + config.encoder.num_channels = IMG_CHANNELS + config.decoder.vocab_size=VOCAB_SIZE + config.decoder.max_position_embeddings=MAX_TOKEN_SIZE + + super().__init__(config=config) @classmethod def from_pretrained(cls, model_path: str = None):