Update files
This commit is contained in:
@@ -30,14 +30,14 @@ class TexTeller(VisionEncoderDecoderModel):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str = None):
|
||||
if model_path is None or model_path == cls.REPO_NAME:
|
||||
if model_path is None or model_path == 'default':
|
||||
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
|
||||
model_path = Path(model_path).resolve()
|
||||
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
|
||||
if tokenizer_path is None or tokenizer_path == cls.REPO_NAME:
|
||||
if tokenizer_path is None or tokenizer_path == 'default':
|
||||
return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
|
||||
tokenizer_path = Path(tokenizer_path).resolve()
|
||||
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
|
||||
|
||||
Reference in New Issue
Block a user