Update infer_det.py
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
import subprocess
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
from pathlib import Path
|
||||
|
||||
from models.det_model.inference import PredictConfig, predict_image
|
||||
|
||||
|
||||
@@ -12,8 +14,8 @@ parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
|
||||
default="./models/det_model/model/infer_cfg.yml")
|
||||
parser.add_argument('--onnx_file', type=str, help="onnx model file path",
|
||||
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--image_file", type=str, required=True)
|
||||
parser.add_argument("--image_dir", type=str, default='./testImgs')
|
||||
parser.add_argument("--image_file", type=str)
|
||||
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
|
||||
|
||||
|
||||
@@ -47,6 +49,10 @@ def get_test_images(infer_dir, infer_img):
|
||||
|
||||
return images
|
||||
|
||||
def download_file(url, filename):
|
||||
print(f"Downloading {filename}...")
|
||||
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
|
||||
print("Download complete.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
cur_path = os.getcwd()
|
||||
@@ -54,6 +60,15 @@ if __name__ == '__main__':
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
FLAGS = parser.parse_args()
|
||||
|
||||
if not os.path.exists(FLAGS.infer_cfg):
|
||||
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
|
||||
download_file(infer_cfg_url, FLAGS.infer_cfg)
|
||||
|
||||
if not os.path.exists(FLAGS.onnx_file):
|
||||
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
|
||||
download_file(onnx_file_url, FLAGS.onnx_file)
|
||||
|
||||
# load image list
|
||||
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
|
||||
# load predictor
|
||||
|
||||
Reference in New Issue
Block a user