inference_transform bugfix

This commit is contained in:
三洋三洋
2024-04-06 05:09:50 +00:00
parent 87ddb86e5e
commit c9c15d27bd
2 changed files with 4 additions and 5 deletions

View File

@@ -187,7 +187,6 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1 , "Only support grayscale images for now" assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) for img in images]
# 裁剪掉白边 # 裁剪掉白边
images = [trim_white_border(image) for image in images] images = [trim_white_border(image) for image in images]
# general transform pipeline # general transform pipeline

View File

@@ -1,9 +1,9 @@
#!/usr/bin/env bash #!/usr/bin/env bash
set -exu set -exu
export CHECKPOINT_DIR="default" export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt"
export TOKENIZER_DIR="default" export TOKENIZER_DIR="/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas"
export USE_CUDA=False # True or False (case-sensitive) export USE_CUDA=True # True or False (case-sensitive)
export NUM_BEAM=1 export NUM_BEAM=3
streamlit run web.py streamlit run web.py