Update infer_det.py
增加使用gpu进行onnx模型推理的功能
This commit is contained in:
@@ -3,7 +3,7 @@ import argparse
|
||||
import glob
|
||||
import subprocess
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
import onnxruntime
|
||||
from pathlib import Path
|
||||
|
||||
from models.det_model.inference import PredictConfig, predict_image
|
||||
@@ -17,6 +17,7 @@ parser.add_argument('--onnx_file', type=str, help="onnx model file path",
|
||||
parser.add_argument("--image_dir", type=str, default='./testImgs')
|
||||
parser.add_argument("--image_file", type=str)
|
||||
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
|
||||
# BooleanOptionalAction (Python 3.9+) keeps the GPU-by-default behavior while
# still letting the user turn it off with --no-use_gpu.  The previous
# action='store_true' + default=True combination made the flag a no-op:
# use_gpu was always True and the CPU branch was unreachable from the CLI.
parser.add_argument('--use_gpu', action=argparse.BooleanOptionalAction, default=True,
                    help='Whether to use GPU for inference (disable with --no-use_gpu)')
|
||||
|
||||
|
||||
def get_test_images(infer_dir, infer_img):
|
||||
@@ -71,8 +72,11 @@ if __name__ == '__main__':
|
||||
|
||||
# load image list
|
||||
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
|
||||
# load predictor: build exactly one InferenceSession, selecting the execution
# provider from --use_gpu.  The previous code first constructed a session with
# no providers argument and then immediately replaced it — a wasted model load,
# and a TypeError on onnxruntime >= 1.9 where `providers` is required whenever
# more than one provider is available.
providers = ['CUDAExecutionProvider'] if FLAGS.use_gpu else ['CPUExecutionProvider']
predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=providers)
|
||||
# load infer config
|
||||
infer_config = PredictConfig(FLAGS.infer_cfg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user