完成了web,ray server,重构了代码
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
from ....globals import (
|
||||
from models.globals import (
|
||||
VOCAB_SIZE,
|
||||
OCR_IMG_SIZE,
|
||||
OCR_IMG_CHANNELS,
|
||||
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str):
|
||||
return VisionEncoderDecoderModel.from_pretrained(model_path)
|
||||
model_path = Path(model_path).resolve()
|
||||
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
|
||||
return RobertaTokenizerFast.from_pretrained(tokenizer_path)
|
||||
tokenizer_path = Path(tokenizer_path).resolve()
|
||||
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# texteller = TexTeller()
|
||||
from ..inference import inference
|
||||
from ..utils.inference import inference
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
|
||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user