From 9d27ee0585dff547ef60b09c91d19192464f9379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Tue, 23 Jan 2024 04:23:08 +0000 Subject: [PATCH] =?UTF-8?q?=E5=86=99=E5=A5=BD=E4=BA=86ocr=5Fmodel=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E8=84=9A=E6=9C=AC=E7=9A=84=E5=A4=A7=E8=87=B4=E6=A1=86?= =?UTF-8?q?=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/globals.py | 9 ++++ src/models/ocr_model/README.md | 6 +++ src/models/ocr_model/inference.py | 0 src/models/ocr_model/model/TexTeller.py | 63 ++++++++++++++++++++++++ src/models/ocr_model/train/train.py | 0 src/models/ocr_model/utils/get.py | 28 +++++++++++ src/models/ocr_model/utils/transforms.py | 0 7 files changed, 106 insertions(+) create mode 100644 src/models/ocr_model/README.md create mode 100644 src/models/ocr_model/inference.py create mode 100644 src/models/ocr_model/model/TexTeller.py create mode 100644 src/models/ocr_model/train/train.py create mode 100644 src/models/ocr_model/utils/get.py create mode 100644 src/models/ocr_model/utils/transforms.py diff --git a/src/globals.py b/src/globals.py index ffc2f91..b8f9530 100644 --- a/src/globals.py +++ b/src/globals.py @@ -15,6 +15,15 @@ MAX_WIDTH = 1280 # ocr模型所用数据集中,图片所用的Density渲染值(实际上图片用的渲染Density不是80,而是100) TEXIFY_INPUT_DENSITY = 80 +# ocr模型的tokenizer中的词典数量 +VOCAB_SIZE = 10000 + +# ocr模型训练时,输入图片所固定的大小 +OCR_IMG_SIZE = 448 + +# ocr模型输入图片的通道数 +OCR_IMG_CHANNELS = 1 # 灰度图 + # ============================================================================= # diff --git a/src/models/ocr_model/README.md b/src/models/ocr_model/README.md new file mode 100644 index 0000000..4625a80 --- /dev/null +++ b/src/models/ocr_model/README.md @@ -0,0 +1,6 @@ +* Encoder-Decoder架构 + +* Encoder使用Deit_{BASE} + +* Decoder使用RoBERTa_{LARGE} + * Decoder的tokenizer也使用RoBERTa_{LARGE}的 \ No newline at end of file diff --git a/src/models/ocr_model/inference.py b/src/models/ocr_model/inference.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py new file mode 100644 index 0000000..c255ef0 --- /dev/null +++ b/src/models/ocr_model/model/TexTeller.py @@ -0,0 +1,63 @@ +from ....globals import ( + VOCAB_SIZE, + OCR_IMG_SIZE, + OCR_IMG_CHANNELS +) + +from typing import ( + Tuple +) + +from transformers import ( + DeiTConfig, + DeiTModel, + + RobertaConfig, + RobertaModel, + RobertaTokenizerFast, + + VisionEncoderDecoderConfig, + VisionEncoderDecoderModel +) + + +class TexTeller: + def __init__(self, encoder_path=None, decoder_path=None, tokenizer_path=None): + self.tokenizer = self.get_tokenizer(tokenizer_path) + + assert not (encoder_path is None and decoder_path is not None) + assert not (encoder_path is not None and decoder_path is None) + + if encoder_path is None: + encoder_config = DeiTConfig( + img_size=OCR_IMG_SIZE, + num_channels=OCR_IMG_CHANNELS + ) + + decoder_config = RobertaConfig( + vocab_size=VOCAB_SIZE, + is_decoder=True + ) + + model_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs( + encoder_config, + decoder_config + ) + self.model = VisionEncoderDecoderModel(model_config) + + else: + self.model = VisionEncoderDecoderModel.from_pretrained( + encoder_path, + decoder_path + ) + + + ... + + @classmethod + def get_tokenizer(tokenizer_path: str = None) -> RobertaTokenizerFast: + if tokenizer_path is None: + return RobertaTokenizerFast() + else: + return RobertaTokenizerFast.from_pretrained(tokenizer_path) + \ No newline at end of file diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/ocr_model/utils/get.py b/src/models/ocr_model/utils/get.py new file mode 100644 index 0000000..9f4f00a --- /dev/null +++ b/src/models/ocr_model/utils/get.py @@ -0,0 +1,28 @@ +from ....globals import VOCAB_SIZE +from typing import ( + Tuple +) + +from transformers import ( + RobertaConfig, + RobertaModel, + RobertaTokenizerFast +) + + +def get_encoder(): + ... + + +def get_tokenizer() -> RobertaTokenizerFast: + ... + + +def get_decoder() -> RobertaModel: + configuration = RobertaConfig( + vocab_size=VOCAB_SIZE, + is_decoder=True + ) + model = RobertaModel(configuration) + return model + diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py new file mode 100644 index 0000000..e69de29