写好了ocr_model训练脚本的大致框架

This commit is contained in:
三洋三洋
2024-01-23 04:23:08 +00:00
parent 703ac7441c
commit 9d27ee0585
7 changed files with 106 additions and 0 deletions

View File

@@ -15,6 +15,15 @@ MAX_WIDTH = 1280
# Rendering density recorded for the images in the OCR model's dataset
# (the images were actually rendered at a density of 100, not 80).
TEXIFY_INPUT_DENSITY = 80
# Number of entries in the vocabulary of the OCR model's tokenizer.
VOCAB_SIZE = 10000
# Fixed size of the input images used when training the OCR model.
OCR_IMG_SIZE = 448
# Number of channels of the OCR model's input images.
OCR_IMG_CHANNELS = 1  # grayscale
# ============================================================================= #

View File

@@ -0,0 +1,6 @@
* Encoder-Decoder架构
* Encoder使用DeiT_{BASE}
* Decoder使用RoBERTa_{LARGE}
* Decoder的tokenizer也使用RoBERTa_{LARGE}的

View File

View File

@@ -0,0 +1,63 @@
from ....globals import (
VOCAB_SIZE,
OCR_IMG_SIZE,
OCR_IMG_CHANNELS
)
from typing import (
Tuple
)
from transformers import (
DeiTConfig,
DeiTModel,
RobertaConfig,
RobertaModel,
RobertaTokenizerFast,
VisionEncoderDecoderConfig,
VisionEncoderDecoderModel
)
class TexTeller:
    """Bundle the OCR vision encoder-decoder model with its tokenizer.

    The model is a ``VisionEncoderDecoderModel`` whose encoder is described by
    a ``DeiTConfig`` and whose decoder is described by a ``RobertaConfig``.

    Attributes:
        tokenizer: The ``RobertaTokenizerFast`` returned by ``get_tokenizer``.
        model: The ``VisionEncoderDecoderModel`` instance.
    """

    def __init__(self, encoder_path=None, decoder_path=None, tokenizer_path=None):
        """Create the tokenizer and the encoder-decoder model.

        Args:
            encoder_path: Location of a pretrained encoder, or ``None``.
            decoder_path: Location of a pretrained decoder, or ``None``.
            tokenizer_path: Location of a pretrained tokenizer, or ``None``.

        Raises:
            ValueError: If exactly one of ``encoder_path`` and
                ``decoder_path`` is given; they must be supplied together
                or both omitted.
        """
        self.tokenizer = self.get_tokenizer(tokenizer_path)

        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if (encoder_path is None) != (decoder_path is None):
            raise ValueError(
                "encoder_path and decoder_path must either both be given "
                "or both be None"
            )

        if encoder_path is None:
            # No checkpoints supplied: build the model from configurations only.
            encoder_config = DeiTConfig(
                # The DeiT config field is ``image_size``; an unknown keyword
                # such as ``img_size`` would not change the input resolution.
                image_size=OCR_IMG_SIZE,
                num_channels=OCR_IMG_CHANNELS,
            )
            decoder_config = RobertaConfig(
                vocab_size=VOCAB_SIZE,
                is_decoder=True,
            )
            model_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
                encoder_config,
                decoder_config,
            )
            self.model = VisionEncoderDecoderModel(model_config)
        else:
            # Two separate checkpoints need the dedicated loader;
            # ``from_pretrained`` expects one combined checkpoint.
            self.model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder_path,
                decoder_path,
            )
        # NOTE(review): the original skeleton ends with a bare ``...`` here,
        # presumably marking setup that is still to be written.
        ...

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        """Return the tokenizer used by the decoder.

        Args:
            tokenizer_path: Location passed to
                ``RobertaTokenizerFast.from_pretrained``. When ``None``, a
                tokenizer is constructed with no arguments.

        Returns:
            A ``RobertaTokenizerFast`` instance.
        """
        if tokenizer_path is not None:
            return RobertaTokenizerFast.from_pretrained(tokenizer_path)
        # NOTE(review): constructing the tokenizer without vocabulary files
        # may fail at runtime -- confirm the intended default.
        return RobertaTokenizerFast()

View File

View File

@@ -0,0 +1,28 @@
from ....globals import VOCAB_SIZE
from typing import (
Tuple
)
from transformers import (
RobertaConfig,
RobertaModel,
RobertaTokenizerFast
)
def get_encoder():
    """Placeholder for the encoder factory; not implemented yet (returns None)."""
    pass
def get_tokenizer() -> RobertaTokenizerFast:
    """Placeholder for the tokenizer factory; not implemented yet (returns None)."""
    pass
def get_decoder() -> RobertaModel:
    """Build a ``RobertaModel`` from a fresh decoder configuration.

    The configuration sets the vocabulary size to ``VOCAB_SIZE`` and enables
    ``is_decoder``; the model is constructed from that configuration alone.

    Returns:
        The newly constructed ``RobertaModel``.
    """
    decoder_config = RobertaConfig(vocab_size=VOCAB_SIZE, is_decoder=True)
    return RobertaModel(decoder_config)

View File