Files
TexTeller/src/models/ocr_model/model/TexTeller.py
2024-01-28 14:03:42 +00:00

55 lines
1.8 KiB
Python

from PIL import Image
from ....globals import (
VOCAB_SIZE,
OCR_IMG_SIZE,
OCR_IMG_CHANNELS,
)
from transformers import (
ViTConfig,
ViTModel,
TrOCRConfig,
TrOCRForCausalLM,
RobertaTokenizerFast,
VisionEncoderDecoderModel,
)
class TexTeller(VisionEncoderDecoderModel):
    """ViT encoder + TrOCR causal-LM decoder built on ``VisionEncoderDecoderModel``.

    The constructor builds a brand-new (not pretrained) encoder/decoder pair
    sized by the project globals ``OCR_IMG_SIZE``, ``OCR_IMG_CHANNELS`` and
    ``VOCAB_SIZE``. Use :meth:`from_pretrained` to load trained weights and
    :meth:`get_tokenizer` to load the matching tokenizer.
    """

    def __init__(self, decoder_path=None, tokenizer_path=None):
        """Create an untrained ViT encoder and TrOCR decoder.

        Args:
            decoder_path: Currently unused; accepted only for backward
                compatibility with existing callers.
            tokenizer_path: Currently unused; accepted only for backward
                compatibility with existing callers.
        """
        encoder = ViTModel(ViTConfig(
            image_size=OCR_IMG_SIZE,
            num_channels=OCR_IMG_CHANNELS,
        ))
        decoder = TrOCRForCausalLM(TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        ))
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def from_pretrained(cls, model_path: str, *model_args, **kwargs) -> VisionEncoderDecoderModel:
        """Load a saved encoder-decoder checkpoint from ``model_path``.

        Loading is delegated to ``VisionEncoderDecoderModel.from_pretrained``
        rather than ``super()``. Going through ``super()`` would instantiate
        via ``cls(config, ...)``, i.e. ``TexTeller.__init__``, which ignores
        the checkpoint's config. As a consequence, the returned object is a
        plain ``VisionEncoderDecoderModel``, not a ``TexTeller`` instance.

        Args:
            model_path: Directory (or hub id) of the saved checkpoint.
            *model_args: Forwarded unchanged to the base loader.
            **kwargs: Forwarded unchanged to the base loader (e.g.
                ``torch_dtype``), matching the base-class signature.

        Returns:
            The loaded ``VisionEncoderDecoderModel``.
        """
        return VisionEncoderDecoderModel.from_pretrained(
            model_path, *model_args, **kwargs
        )

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str, **kwargs) -> RobertaTokenizerFast:
        """Load the ``RobertaTokenizerFast`` saved at ``tokenizer_path``.

        Args:
            tokenizer_path: Directory (or hub id) of the saved tokenizer.
            **kwargs: Forwarded unchanged to
                ``RobertaTokenizerFast.from_pretrained``.
        """
        return RobertaTokenizerFast.from_pretrained(tokenizer_path, **kwargs)
if __name__ == "__main__":
    # Ad-hoc smoke test: load a local checkpoint and tokenizer, run
    # ``inference`` on a handful of sample images, and print the result.
    # NOTE: this module uses relative imports (here and at the top of the
    # file), so it must be run as a module (``python -m ...``), not as a
    # plain script.
    # NOTE(review): the default paths below are developer-machine specific.
    # Pass image paths on the command line to override the sample images.
    import sys
    from pathlib import Path

    from ..inference import inference

    checkpoint_dir = '/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500'
    tokenizer_dir = '/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas'
    sample_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model')

    model = TexTeller.from_pretrained(checkpoint_dir)
    tokenizer = TexTeller.get_tokenizer(tokenizer_dir)

    image_paths = sys.argv[1:] or [sample_dir / f'{i}.png' for i in range(1, 7)]
    images = []
    for path in image_paths:
        with Image.open(path) as img:
            # Read the pixel data now so the file handle can be closed;
            # the loaded image remains usable after the ``with`` block.
            img.load()
        images.append(img)

    res = inference(model, images, tokenizer)
    print(res)