diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py index 63273a8..92bc08c 100644 --- a/src/models/ocr_model/utils/inference.py +++ b/src/models/ocr_model/utils/inference.py @@ -13,7 +13,7 @@ from models.globals import MAX_TOKEN_SIZE def inference( model: TexTeller, tokenizer: RobertaTokenizerFast, - imgs_path: Union[List[str], List[np.ndarray]], + imgs: Union[List[str], List[np.ndarray]], inf_mode: str = 'cpu', num_beams: int = 1, ) -> List[str]: diff --git a/src/start_web.sh b/src/start_web.sh index 6ec8f7b..41e6311 100755 --- a/src/start_web.sh +++ b/src/start_web.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -exu -export CHECKPOINT_DIR="default" +export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-648000" export TOKENIZER_DIR="default" streamlit run web.py