Finished the model code, tokenizer, data preprocessing, and training script; however, the training script does not yet configure generate, so evaluation can only report the loss.
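Since the commit message notes that evaluation currently only reports the loss, here is a minimal, hypothetical sketch of how generation-based evaluation could be wired up with Hugging Face's Seq2SeqTrainer. This is not the repository's actual training script (which is not part of this diff); the dataset objects, output path, metric, and decoding budget are placeholders.

import numpy as np
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

def build_trainer(model, tokenizer, train_dataset, eval_dataset):
    # predict_with_generate makes the Trainer call model.generate() during
    # evaluation, so metrics beyond the loss can be computed on decoded text.
    args = Seq2SeqTrainingArguments(
        output_dir="ckpts",               # placeholder path
        predict_with_generate=True,
        generation_max_length=256,        # assumed decoding budget
        evaluation_strategy="steps",
        eval_steps=500,
    )

    def compute_metrics(eval_pred):
        pred_ids, label_ids = eval_pred.predictions, eval_pred.label_ids
        # Labels are padded with -100 for the loss; map them back before decoding.
        label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        # Exact-match rate as a stand-in metric; BLEU or edit distance would also work.
        exact = sum(p == l for p, l in zip(pred_str, label_str)) / len(label_str)
        return {"exact_match": exact}

    # A data collator that batches pixel_values and labels is assumed to exist
    # in the real training script and is omitted here.
    return Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

Note that for a VisionEncoderDecoderModel, decoder_start_token_id and pad_token_id also have to be set on the model config before generate() will run.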
@@ -4,60 +4,43 @@ from ....globals import (
    OCR_IMG_CHANNELS
)

from typing import (
    Tuple
)

from transformers import (
    DeiTConfig,
    DeiTModel,
    ViTConfig,
    ViTModel,

    TrOCRConfig,
    TrOCRForCausalLM,

    RobertaConfig,
    RobertaModel,
    RobertaTokenizerFast,

    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel
)


class TexTeller:
    def __init__(self, encoder_path=None, decoder_path=None, tokenizer_path=None):
        self.tokenizer = self.get_tokenizer(tokenizer_path)

        assert not (encoder_path is None and decoder_path is not None)
        assert not (encoder_path is not None and decoder_path is None)

        if encoder_path is None:
            encoder_config = DeiTConfig(
                img_size=OCR_IMG_SIZE,
                num_channels=OCR_IMG_CHANNELS
            )

            decoder_config = RobertaConfig(
                vocab_size=VOCAB_SIZE,
                is_decoder=True
            )

            model_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
                encoder_config,
                decoder_config
            )
            self.model = VisionEncoderDecoderModel(model_config)

        else:
            self.model = VisionEncoderDecoderModel.from_pretrained(
                encoder_path,
                decoder_path
            )


    ...
class TexTeller(VisionEncoderDecoderModel):
    def __init__(self, decoder_path=None, tokenizer_path=None):
        encoder = ViTModel(ViTConfig(
            image_size=OCR_IMG_SIZE,
            num_channels=OCR_IMG_CHANNELS
        ))
        decoder = TrOCRForCausalLM(TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        ))
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def from_pretrained(cls, model_path: str):
        return VisionEncoderDecoderModel.from_pretrained(model_path)

    @classmethod
    def get_tokenizer(tokenizer_path: str = None) -> RobertaTokenizerFast:
        if tokenizer_path is None:
            return RobertaTokenizerFast()
        else:
            return RobertaTokenizerFast.from_pretrained(tokenizer_path)

    def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
        return RobertaTokenizerFast.from_pretrained(tokenizer_path)


if __name__ == "__main__":
    texteller = TexTeller()
    tokenizer = texteller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
    foo = ["Hello, my name is LHY.", "I am a researcher at the University of Science and Technology of China."]
    bar = tokenizer(foo, return_special_tokens_mask=True)
    pause = 1
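For completeness, a hypothetical inference-time sketch of how the new TexTeller class could be exercised once trained. The checkpoint loading, image preprocessing, and generation settings below are illustrative assumptions, not part of this commit.

import torch

def latex_of_image(pixel_values: torch.Tensor, model, tokenizer) -> str:
    # pixel_values: a (1, OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_IMG_SIZE) float tensor,
    # preprocessed the same way as during training (resize / normalize).
    model.eval()
    with torch.no_grad():
        token_ids = model.generate(pixel_values, max_new_tokens=256)  # assumed budget
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)[0]

Usage would then be roughly: load the model with TexTeller.from_pretrained(...), the tokenizer with TexTeller.get_tokenizer(...), and call latex_of_image on a preprocessed formula image.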