update infer_det.py

This commit is contained in:
三洋三洋
2024-04-18 00:06:05 +08:00
parent d5eca45fcc
commit 3746ddd427

View File

@@ -9,7 +9,6 @@ from tqdm import tqdm
from models.det_model.preprocess import Compose
import cv2
# 注意:文件名要标准,最好都用下划线
# Global dictionary
SUPPORT_MODELS = {
@@ -85,9 +84,7 @@ class PredictConfig(object):
self.nms = yml_conf.get("NMS", None)
self.fpn_stride = yml_conf.get("fpn_stride", None)
# 预定义颜色池
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
# 根据label_list动态生成颜色映射
self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
@@ -120,9 +117,7 @@ def draw_bbox(image, outputs, infer_config):
for output in outputs:
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
# 获取类别名
label = infer_config.label_list[int(cls_id)]
# 根据类别名获取颜色
color = infer_config.colors[label]
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
cv2.putText(image, "{}: {:.2f}".format(label, score),