写完了模型代码、Tokenizer、数据预处理、训练脚本,但目前的训练脚本没有配置generate(评估仅能看loss)

This commit is contained in:
三洋三洋
2024-01-28 06:19:23 +00:00
parent 9d27ee0585
commit c6d5c91955
18 changed files with 80058 additions and 78 deletions

View File

@@ -4,60 +4,43 @@ from ....globals import (
OCR_IMG_CHANNELS
)
from typing import (
Tuple
)
from transformers import (
DeiTConfig,
DeiTModel,
ViTConfig,
ViTModel,
TrOCRConfig,
TrOCRForCausalLM,
RobertaConfig,
RobertaModel,
RobertaTokenizerFast,
VisionEncoderDecoderConfig,
VisionEncoderDecoderModel
)
class TexTeller:
    """Formula-OCR model: a DeiT vision encoder paired with a RoBERTa causal decoder.

    When no checkpoint paths are given, a fresh ``VisionEncoderDecoderModel``
    is built from config (train-from-scratch); otherwise pretrained encoder
    and decoder checkpoints are loaded. The two paths must be supplied
    together.
    """

    def __init__(self, encoder_path=None, decoder_path=None, tokenizer_path=None):
        self.tokenizer = self.get_tokenizer(tokenizer_path)
        # Validate with an explicit raise rather than `assert`: asserts are
        # stripped under `python -O`, which would silently accept a
        # half-specified model.
        if (encoder_path is None) != (decoder_path is None):
            raise ValueError(
                "encoder_path and decoder_path must be provided together"
            )
        if encoder_path is None:
            # Train from scratch: compose encoder/decoder configs into one
            # VisionEncoderDecoderConfig and instantiate with random weights.
            encoder_config = DeiTConfig(
                img_size=OCR_IMG_SIZE,
                num_channels=OCR_IMG_CHANNELS
            )
            decoder_config = RobertaConfig(
                vocab_size=VOCAB_SIZE,
                is_decoder=True  # enables cross-attention / causal masking
            )
            model_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
                encoder_config,
                decoder_config
            )
            self.model = VisionEncoderDecoderModel(model_config)
        else:
            # BUG FIX: `from_pretrained(encoder_path, decoder_path)` passed the
            # decoder path as a positional *model_args entry, not as a decoder
            # checkpoint. Loading separate encoder/decoder checkpoints requires
            # `from_encoder_decoder_pretrained`.
            self.model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder_path,
                decoder_path
            )
...
class TexTeller(VisionEncoderDecoderModel):
    """Formula-OCR model: a ViT vision encoder paired with a TrOCR causal decoder.

    Subclasses ``VisionEncoderDecoderModel`` directly so the instance itself
    is the HF model (``save_pretrained`` / ``generate`` work on it).
    """

    def __init__(self, decoder_path=None, tokenizer_path=None):
        # Fresh (randomly initialized) encoder/decoder sized by project globals.
        encoder = ViTModel(ViTConfig(
            image_size=OCR_IMG_SIZE,
            num_channels=OCR_IMG_CHANNELS
        ))
        decoder = TrOCRForCausalLM(TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        ))
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def from_pretrained(cls, model_path: str):
        # NOTE(review): deliberately delegates to the base class (returns a
        # plain VisionEncoderDecoderModel, not a TexTeller) — this class's
        # __init__ does not accept a `config` argument, so the generic
        # `cls(config)` instantiation path inside from_pretrained would fail.
        return VisionEncoderDecoderModel.from_pretrained(model_path)

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        """Load the formula tokenizer from *tokenizer_path*.

        BUG FIX: the previous overload was decorated @classmethod but its
        signature lacked ``cls``, so the class object landed in
        ``tokenizer_path``; its no-path fallback ``RobertaTokenizerFast()``
        also raises, since a fast tokenizer cannot be built without files.
        The ``None`` default is kept only for call-compatibility and now
        fails with a clear message.
        """
        if tokenizer_path is None:
            raise ValueError("tokenizer_path is required to load the tokenizer")
        return RobertaTokenizerFast.from_pretrained(tokenizer_path)
if __name__ == "__main__":
    # Smoke test: build a fresh model and round-trip two sentences through
    # the tokenizer.
    texteller = TexTeller()
    tokenizer = texteller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
    foo = ["Hello, my name is LHY.", "I am a researcher at the University of Science and Technology of China."]
    bar = tokenizer(foo, return_special_tokens_mask=True)
    # Replaced the `pause = 1` debugger-breakpoint leftover with output so the
    # smoke test actually shows its result.
    print(bar)