写好了ocr_model训练脚本的大致框架
This commit is contained in:
@@ -15,6 +15,15 @@ MAX_WIDTH = 1280
|
|||||||
# ocr模型所用数据集中,图片所用的Density渲染值(实际上图片用的渲染Density不是80,而是100)
|
# ocr模型所用数据集中,图片所用的Density渲染值(实际上图片用的渲染Density不是80,而是100)
|
||||||
TEXIFY_INPUT_DENSITY = 80
|
TEXIFY_INPUT_DENSITY = 80
|
||||||
|
|
||||||
|
# ocr模型的tokenizer中的词典数量
|
||||||
|
VOCAB_SIZE = 10000
|
||||||
|
|
||||||
|
# ocr模型训练时,输入图片所固定的大小
|
||||||
|
OCR_IMG_SIZE = 448
|
||||||
|
|
||||||
|
# ocr模型输入图片的通道数
|
||||||
|
OCR_IMG_CHANNELS = 1 # 灰度图
|
||||||
|
|
||||||
# ============================================================================= #
|
# ============================================================================= #
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
6
src/models/ocr_model/README.md
Normal file
6
src/models/ocr_model/README.md
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
* Encoder-Decoder架构
|
||||||
|
|
||||||
|
* Encoder使用Deit_{BASE}
|
||||||
|
|
||||||
|
* Decoder使用RoBERTa_{LARGE}
|
||||||
|
* Decoder的tokenizer也使用RoBERTa_{LARGE}的
|
||||||
0
src/models/ocr_model/inference.py
Normal file
0
src/models/ocr_model/inference.py
Normal file
63
src/models/ocr_model/model/TexTeller.py
Normal file
63
src/models/ocr_model/model/TexTeller.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from ....globals import (
|
||||||
|
VOCAB_SIZE,
|
||||||
|
OCR_IMG_SIZE,
|
||||||
|
OCR_IMG_CHANNELS
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Tuple
|
||||||
|
)
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
DeiTConfig,
|
||||||
|
DeiTModel,
|
||||||
|
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaModel,
|
||||||
|
RobertaTokenizerFast,
|
||||||
|
|
||||||
|
VisionEncoderDecoderConfig,
|
||||||
|
VisionEncoderDecoderModel
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TexTeller:
|
||||||
|
def __init__(self, encoder_path=None, decoder_path=None, tokenizer_path=None):
|
||||||
|
self.tokenizer = self.get_tokenizer(tokenizer_path)
|
||||||
|
|
||||||
|
assert not (encoder_path is None and decoder_path is not None)
|
||||||
|
assert not (encoder_path is not None and decoder_path is None)
|
||||||
|
|
||||||
|
if encoder_path is None:
|
||||||
|
encoder_config = DeiTConfig(
|
||||||
|
img_size=OCR_IMG_SIZE,
|
||||||
|
num_channels=OCR_IMG_CHANNELS
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_config = RobertaConfig(
|
||||||
|
vocab_size=VOCAB_SIZE,
|
||||||
|
is_decoder=True
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
|
||||||
|
encoder_config,
|
||||||
|
decoder_config
|
||||||
|
)
|
||||||
|
self.model = VisionEncoderDecoderModel(model_config)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.model = VisionEncoderDecoderModel.from_pretrained(
|
||||||
|
encoder_path,
|
||||||
|
decoder_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tokenizer(tokenizer_path: str = None) -> RobertaTokenizerFast:
|
||||||
|
if tokenizer_path is None:
|
||||||
|
return RobertaTokenizerFast()
|
||||||
|
else:
|
||||||
|
return RobertaTokenizerFast.from_pretrained(tokenizer_path)
|
||||||
|
|
||||||
0
src/models/ocr_model/train/train.py
Normal file
0
src/models/ocr_model/train/train.py
Normal file
28
src/models/ocr_model/utils/get.py
Normal file
28
src/models/ocr_model/utils/get.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from ....globals import VOCAB_SIZE
|
||||||
|
from typing import (
|
||||||
|
Tuple
|
||||||
|
)
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaModel,
|
||||||
|
RobertaTokenizerFast
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_encoder():
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer() -> RobertaTokenizerFast:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def get_decoder() -> RobertaModel:
|
||||||
|
configuration = RobertaConfig(
|
||||||
|
vocab_size=VOCAB_SIZE,
|
||||||
|
is_decoder=True
|
||||||
|
)
|
||||||
|
model = RobertaModel(configuration)
|
||||||
|
return model
|
||||||
|
|
||||||
0
src/models/ocr_model/utils/transforms.py
Normal file
0
src/models/ocr_model/utils/transforms.py
Normal file
Reference in New Issue
Block a user