From 48043d11e380a0b991542c7d66e8db6426b1a7ec Mon Sep 17 00:00:00 2001
From: TonyLee1256 <163754792+TonyLee1256@users.noreply.github.com>
Date: Thu, 9 May 2024 00:19:39 +0800
Subject: [PATCH] Update infer_det.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

增加使用gpu进行onnx模型推理的功能
---
 src/infer_det.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/infer_det.py b/src/infer_det.py
index b90a1c8..00baf9e 100644
--- a/src/infer_det.py
+++ b/src/infer_det.py
@@ -3,7 +3,7 @@ import argparse
 import glob
 import subprocess
 
-from onnxruntime import InferenceSession
+import onnxruntime
 from pathlib import Path
 
 from models.det_model.inference import PredictConfig, predict_image
@@ -17,6 +17,7 @@ parser.add_argument('--onnx_file', type=str, help="onnx model file path",
 parser.add_argument("--image_dir", type=str, default='./testImgs')
 parser.add_argument("--image_file", type=str)
 parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
+parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU for inference', default=False)
 
 
 def get_test_images(infer_dir, infer_img):
@@ -71,8 +72,11 @@ if __name__ == '__main__':
 
     # load image list
     img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
-    # load predictor
-    predictor = InferenceSession(FLAGS.onnx_file)
+
+    if FLAGS.use_gpu:
+        predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+    else:
+        predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CPUExecutionProvider'])
 
     # load infer config
     infer_config = PredictConfig(FLAGS.infer_cfg)