Update infer_det.py

This commit is contained in:
TonyLee1256
2024-04-22 00:07:41 +08:00
committed by GitHub
parent be19ed8d63
commit 0bb11bebfc

View File

@@ -1,9 +1,11 @@
import os import os
import argparse import argparse
import glob import glob
import subprocess
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from pathlib import Path from pathlib import Path
from models.det_model.inference import PredictConfig, predict_image from models.det_model.inference import PredictConfig, predict_image
@@ -12,8 +14,8 @@ parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
default="./models/det_model/model/infer_cfg.yml") default="./models/det_model/model/infer_cfg.yml")
parser.add_argument('--onnx_file', type=str, help="onnx model file path", parser.add_argument('--onnx_file', type=str, help="onnx model file path",
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str, default='./testImgs')
parser.add_argument("--image_file", type=str, required=True) parser.add_argument("--image_file", type=str)
parser.add_argument("--imgsave_dir", type=str, default="./detect_results") parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
@@ -47,6 +49,10 @@ def get_test_images(infer_dir, infer_img):
return images return images
def download_file(url, filename):
print(f"Downloading {filename}...")
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
print("Download complete.")
if __name__ == '__main__': if __name__ == '__main__':
cur_path = os.getcwd() cur_path = os.getcwd()
@@ -54,6 +60,15 @@ if __name__ == '__main__':
os.chdir(script_dirpath) os.chdir(script_dirpath)
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
if not os.path.exists(FLAGS.infer_cfg):
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
download_file(infer_cfg_url, FLAGS.infer_cfg)
if not os.path.exists(FLAGS.onnx_file):
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
download_file(onnx_file_url, FLAGS.onnx_file)
# load image list # load image list
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
# load predictor # load predictor