diff --git a/src/infer_det.py b/src/infer_det.py index b90a1c8..00baf9e 100644 --- a/src/infer_det.py +++ b/src/infer_det.py @@ -3,7 +3,7 @@ import argparse import glob import subprocess -from onnxruntime import InferenceSession +import onnxruntime from pathlib import Path from models.det_model.inference import PredictConfig, predict_image @@ -17,6 +17,7 @@ parser.add_argument('--onnx_file', type=str, help="onnx model file path", parser.add_argument("--image_dir", type=str, default='./testImgs') parser.add_argument("--image_file", type=str) parser.add_argument("--imgsave_dir", type=str, default="./detect_results") +parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True) def get_test_images(infer_dir, infer_img): @@ -71,8 +72,11 @@ if __name__ == '__main__': # load image list img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) - # load predictor - predictor = InferenceSession(FLAGS.onnx_file) + + if FLAGS.use_gpu: + predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CUDAExecutionProvider']) + else: + predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CPUExecutionProvider']) # load infer config infer_config = PredictConfig(FLAGS.infer_cfg)