55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
from PIL import Image
|
|
|
|
from ....globals import (
|
|
VOCAB_SIZE,
|
|
OCR_IMG_SIZE,
|
|
OCR_IMG_CHANNELS,
|
|
)
|
|
|
|
from transformers import (
|
|
ViTConfig,
|
|
ViTModel,
|
|
TrOCRConfig,
|
|
TrOCRForCausalLM,
|
|
RobertaTokenizerFast,
|
|
VisionEncoderDecoderModel,
|
|
)
|
|
|
|
|
|
class TexTeller(VisionEncoderDecoderModel):
    """Vision encoder-decoder model pairing a ViT encoder with a TrOCR decoder.

    The constructor builds both halves from fresh configs (no pretrained
    weights are loaded here); use :meth:`from_pretrained` to load a trained
    checkpoint from disk.
    """

    def __init__(self, decoder_path=None, tokenizer_path=None) -> None:
        """Build the encoder/decoder pair from the project-wide constants.

        Args:
            decoder_path: Accepted for backward compatibility; currently unused.
            tokenizer_path: Accepted for backward compatibility; currently
                unused.
        """
        # Encoder geometry comes from the project globals so that it matches
        # whatever preprocessing the rest of the package applies.
        encoder_config = ViTConfig(
            image_size=OCR_IMG_SIZE,
            num_channels=OCR_IMG_CHANNELS,
        )
        decoder_config = TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        )
        encoder = ViTModel(encoder_config)
        decoder = TrOCRForCausalLM(decoder_config)
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def from_pretrained(cls, model_path: str, *model_args, **kwargs) -> VisionEncoderDecoderModel:
        """Load a trained checkpoint from ``model_path``.

        The call is delegated to ``VisionEncoderDecoderModel.from_pretrained``
        directly (not ``cls``/``super()``), so the returned object is a plain
        ``VisionEncoderDecoderModel`` rather than a ``TexTeller`` instance.

        Args:
            model_path: Checkpoint location passed to the base loader.
            *model_args: Extra positional arguments forwarded unchanged to the
                base loader.
            **kwargs: Extra keyword arguments forwarded unchanged to the base
                loader, so callers can use the standard loading options
                instead of hitting a ``TypeError`` on this override.

        Returns:
            The loaded ``VisionEncoderDecoderModel``.
        """
        return VisionEncoderDecoderModel.from_pretrained(
            model_path, *model_args, **kwargs
        )

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
        """Load the ``RobertaTokenizerFast`` stored at ``tokenizer_path``."""
        return RobertaTokenizerFast.from_pretrained(tokenizer_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load a trained checkpoint plus tokenizer and run
    # inference on a handful of sample images.
    # NOTE: the relative imports in this module only resolve when it is run as
    # part of its package (e.g. via ``python -m``), not as a standalone script.
    from ..inference import inference

    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500')
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

    # Sample images are named 1.png through 6.png inside the model directory.
    sample_dir = '/home/lhy/code/TeXify/src/models/ocr_model/model'
    images = [Image.open(f'{sample_dir}/{index}.png') for index in range(1, 7)]

    res = inference(model, images, tokenizer)
    # NOTE(review): presumably a no-op line kept as a debugger breakpoint
    # anchor for inspecting ``res`` — confirm before removing.
    pause = 1
|
|
|