Update inference.py

增加了计时功能
This commit is contained in:
TonyLee1256
2024-05-09 00:20:32 +08:00
committed by GitHub
parent 48043d11e3
commit fe7e4a7af0

View File

@@ -1,4 +1,5 @@
import os import os
import time
import yaml import yaml
import numpy as np import numpy as np
import cv2 import cv2
@@ -90,6 +91,9 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
subimg_save_dir = os.path.join(imgsave_dir, 'subimages') subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
os.makedirs(subimg_save_dir, exist_ok=True) os.makedirs(subimg_save_dir, exist_ok=True)
first_image_skipped = False
total_time = 0
num_images = 0
# predict image # predict image
for img_path in tqdm(img_list): for img_path in tqdm(img_list):
img = cv2.imread(img_path) img = cv2.imread(img_path)
@@ -102,8 +106,21 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
inputs_name = [var.name for var in predictor.get_inputs()] inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name} inputs = {k: inputs[k][None, ] for k in inputs_name}
# Start timing
start_time = time.time()
outputs = predictor.run(output_names=None, input_feed=inputs) outputs = predictor.run(output_names=None, input_feed=inputs)
# Stop timing
end_time = time.time()
inference_time = end_time - start_time
if not first_image_skipped:
first_image_skipped = True
else:
total_time += inference_time
num_images += 1
print(f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds")
print("ONNXRuntime predict: ") print("ONNXRuntime predict: ")
if infer_config.arch in ["HRNet"]: if infer_config.arch in ["HRNet"]:
print(np.array(outputs[0])) print(np.array(outputs[0]))
@@ -130,12 +147,29 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
subimg_counter += 1 subimg_counter += 1
# Draw bounding boxes and save the image with bounding boxes # Draw bounding boxes and save the image with bounding boxes
img_with_mask = img.copy()
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
cv2.rectangle(img_with_mask, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 255), -1) # 盖白
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config) img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
output_dir = imgsave_dir output_dir = imgsave_dir
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path)) draw_box_dir = os.path.join(output_dir, 'draw_box')
cv2.imwrite(output_file, img_with_bbox) mask_white_dir = os.path.join(output_dir, 'mask_white')
os.makedirs(draw_box_dir, exist_ok=True)
os.makedirs(mask_white_dir, exist_ok=True)
output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path))
output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path))
cv2.imwrite(output_file_mask, img_with_mask)
cv2.imwrite(output_file_bbox, img_with_bbox)
avg_time_per_image = total_time / num_images if num_images > 0 else 0
print(f"Total inference time for {num_images} images: {total_time:.4f} seconds")
print(f"Average time per image: {avg_time_per_image:.4f} seconds")
print("ErrorImgs:") print("ErrorImgs:")
print(errImgList) print(errImgList)