diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py
index 25671de..9a0bc8a 100644
--- a/src/models/ocr_model/train/train.py
+++ b/src/models/ocr_model/train/train.py
@@ -96,7 +96,7 @@ if __name__ == '__main__':
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

     # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
+    model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')

     # ================= debug =======================
     # foo = train_dataset[:50]
diff --git a/src/start_web.sh b/src/start_web.sh
index 3d700cc..7dab5d2 100755
--- a/src/start_web.sh
+++ b/src/start_web.sh
@@ -1,8 +1,8 @@
 #!/usr/bin/env bash
 set -exu

-# export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt"
-export CHECKPOINT_DIR="default"
+export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-460000"
+# export CHECKPOINT_DIR="default"
 export TOKENIZER_DIR="/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas"
 export USE_CUDA=True # True or False (case-sensitive)
 export NUM_BEAM=3
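
For reference, a minimal sketch (not the repository's actual loader; the import path and the load_model helper name are assumptions) of how the CHECKPOINT_DIR variable exported in start_web.sh could be resolved before the web app builds its model: "default" falls back to the stock constructor, while any other value is treated as a local checkpoint directory passed to TexTeller.from_pretrained, as in train.py above.

import os

from models.ocr_model.model.TexTeller import TexTeller  # import path assumed

def load_model():
    # Resolve the checkpoint directory exported by start_web.sh;
    # an unset variable behaves the same as the "default" setting.
    ckpt_dir = os.environ.get("CHECKPOINT_DIR", "default")
    if ckpt_dir == "default":
        return TexTeller()  # stock weights, mirroring the commented-out line in train.py
    return TexTeller.from_pretrained(ckpt_dir)  # e.g. .../TexTellerv3/checkpoint-460000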