diff --git a/src/globals.py b/src/globals.py
index ffc2f91..b8f9530 100644
--- a/src/globals.py
+++ b/src/globals.py
@@ -15,6 +15,15 @@ MAX_WIDTH = 1280
 # ocr模型所用数据集中,图片所用的Density渲染值(实际上图片用的渲染Density不是80,而是100)
 TEXIFY_INPUT_DENSITY = 80
 
+# Number of entries in the OCR model tokenizer's vocabulary
+VOCAB_SIZE = 10000
+
+# Fixed input image size used when training the OCR model
+OCR_IMG_SIZE = 448
+
+# Number of channels in the OCR model's input images
+OCR_IMG_CHANNELS = 1  # grayscale
+
 # ============================================================================= #
 
 
diff --git a/src/models/ocr_model/README.md b/src/models/ocr_model/README.md
new file mode 100644
index 0000000..4625a80
--- /dev/null
+++ b/src/models/ocr_model/README.md
@@ -0,0 +1,6 @@
+* Encoder-Decoder架构
+
+* Encoder使用Deit_{BASE}
+
+* Decoder使用RoBERTa_{LARGE}
+    * Decoder的tokenizer也使用RoBERTa_{LARGE}的
\ No newline at end of file
diff --git a/src/models/ocr_model/inference.py b/src/models/ocr_model/inference.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py
new file mode 100644
index 0000000..c255ef0
--- /dev/null
+++ b/src/models/ocr_model/model/TexTeller.py
@@ -0,0 +1,97 @@
+from ....globals import (
+    VOCAB_SIZE,
+    OCR_IMG_SIZE,
+    OCR_IMG_CHANNELS
+)
+
+from typing import (
+    Tuple
+)
+
+from transformers import (
+    DeiTConfig,
+    DeiTModel,
+
+    RobertaConfig,
+    RobertaModel,
+    RobertaTokenizerFast,
+
+    VisionEncoderDecoderConfig,
+    VisionEncoderDecoderModel
+)
+
+
+class TexTeller:
+    """Pair a RoBERTa tokenizer with a DeiT-encoder / RoBERTa-decoder model.
+
+    Attributes:
+        tokenizer: Tokenizer returned by ``get_tokenizer``.
+        model: The ``VisionEncoderDecoderModel`` built in ``__init__``.
+    """
+
+    def __init__(self, encoder_path=None, decoder_path=None, tokenizer_path=None):
+        """Create the tokenizer and the encoder-decoder model.
+
+        Args:
+            encoder_path: Pretrained encoder checkpoint, or None.
+            decoder_path: Pretrained decoder checkpoint, or None.
+            tokenizer_path: Pretrained tokenizer location, or None.
+
+        Raises:
+            ValueError: If exactly one of ``encoder_path`` and
+                ``decoder_path`` is given.
+        """
+        self.tokenizer = self.get_tokenizer(tokenizer_path)
+
+        # Raise instead of assert: asserts are stripped under ``python -O``.
+        if (encoder_path is None) != (decoder_path is None):
+            raise ValueError(
+                "encoder_path and decoder_path must both be set or both be None"
+            )
+
+        if encoder_path is None:
+            # No checkpoints given: build the model from fresh configs.
+            # DeiTConfig's size keyword is ``image_size``; an unrecognised
+            # ``img_size`` would not set the encoder resolution.
+            encoder_config = DeiTConfig(
+                image_size=OCR_IMG_SIZE,
+                num_channels=OCR_IMG_CHANNELS
+            )
+
+            decoder_config = RobertaConfig(
+                vocab_size=VOCAB_SIZE,
+                is_decoder=True
+            )
+
+            model_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
+                encoder_config,
+                decoder_config
+            )
+            self.model = VisionEncoderDecoderModel(model_config)
+
+        else:
+            # ``from_pretrained`` loads one combined checkpoint; separate
+            # encoder and decoder checkpoints need this constructor.
+            self.model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
+                encoder_path,
+                decoder_path
+            )
+
+        ...
+
+    @staticmethod
+    def get_tokenizer(tokenizer_path: str = None) -> RobertaTokenizerFast:
+        """Return the tokenizer to use with the model.
+
+        Static: the signature takes no ``cls``, and ``__init__`` calls it
+        with just the path.
+
+        Args:
+            tokenizer_path: Location passed to ``from_pretrained``, or None
+                for a default-constructed tokenizer.
+        """
+        if tokenizer_path is None:
+            # NOTE(review): RobertaTokenizerFast() without vocab/merges files
+            # presumably cannot build a usable tokenizer -- confirm intent.
+            return RobertaTokenizerFast()
+        return RobertaTokenizerFast.from_pretrained(tokenizer_path)
diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/models/ocr_model/utils/get.py b/src/models/ocr_model/utils/get.py
new file mode 100644
index 0000000..9f4f00a
--- /dev/null
+++ b/src/models/ocr_model/utils/get.py
@@ -0,0 +1,29 @@
+from ....globals import VOCAB_SIZE
+from typing import (
+    Tuple
+)
+
+from transformers import (
+    RobertaConfig,
+    RobertaModel,
+    RobertaTokenizerFast
+)
+
+
+def get_encoder():
+    """Placeholder: not implemented yet, returns None."""
+    ...
+
+
+def get_tokenizer() -> RobertaTokenizerFast:
+    """Placeholder: not implemented yet, returns None."""
+    ...
+
+
+def get_decoder() -> RobertaModel:
+    """Build a new RobertaModel decoder with a VOCAB_SIZE vocabulary."""
+    decoder_config = RobertaConfig(
+        vocab_size=VOCAB_SIZE,
+        is_decoder=True
+    )
+    return RobertaModel(decoder_config)
diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py
new file mode 100644
index 0000000..e69de29