完成了web,ray server,重构了代码

This commit is contained in:
三洋三洋
2024-02-08 13:48:34 +00:00
parent 07c4c3dc01
commit 04b99b8451
20 changed files with 245 additions and 57 deletions

View File

@@ -1,6 +1,7 @@
from PIL import Image
from pathlib import Path
from ....globals import (
from models.globals import (
VOCAB_SIZE,
OCR_IMG_SIZE,
OCR_IMG_CHANNELS,
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
@classmethod
def from_pretrained(cls, model_path: str):
return VisionEncoderDecoderModel.from_pretrained(model_path)
model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
@classmethod
def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
return RobertaTokenizerFast.from_pretrained(tokenizer_path)
tokenizer_path = Path(tokenizer_path).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
if __name__ == "__main__":
# texteller = TexTeller()
from ..inference import inference
from ..utils.inference import inference
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')