diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py index 1982cc3..7c014ae 100644 --- a/src/models/ocr_model/utils/transforms.py +++ b/src/models/ocr_model/utils/transforms.py @@ -187,7 +187,6 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: assert IMG_CHANNELS == 1 , "Only support grayscale images for now" - images = [np.array(img.convert('RGB')) for img in images] # 裁剪掉白边 images = [trim_white_border(image) for image in images] # general transform pipeline diff --git a/src/start_web.sh b/src/start_web.sh index 450dff2..7475147 100755 --- a/src/start_web.sh +++ b/src/start_web.sh @@ -1,9 +1,9 @@ #!/usr/bin/env bash set -exu -export CHECKPOINT_DIR="default" -export TOKENIZER_DIR="default" -export USE_CUDA=False # True or False (case-sensitive) -export NUM_BEAM=1 +export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt" +export TOKENIZER_DIR="/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas" +export USE_CUDA=True # True or False (case-sensitive) +export NUM_BEAM=3 streamlit run web.py