diff --git a/src/infer_det.py b/src/infer_det.py index cc5da44..2410b81 100644 --- a/src/infer_det.py +++ b/src/infer_det.py @@ -9,7 +9,6 @@ from tqdm import tqdm from models.det_model.preprocess import Compose import cv2 -# 注意:文件名要标准,最好都用下划线 # Global dictionary SUPPORT_MODELS = { @@ -85,9 +84,7 @@ class PredictConfig(object): self.nms = yml_conf.get("NMS", None) self.fpn_stride = yml_conf.get("fpn_stride", None) - # 预定义颜色池 color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] - # 根据label_list动态生成颜色映射 self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)} if self.arch == 'RCNN' and yml_conf.get('export_onnx', False): @@ -120,9 +117,7 @@ def draw_bbox(image, outputs, infer_config): for output in outputs: cls_id, score, xmin, ymin, xmax, ymax = output if score > infer_config.draw_threshold: - # 获取类别名 label = infer_config.label_list[int(cls_id)] - # 根据类别名获取颜色 color = infer_config.colors[label] cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2) cv2.putText(image, "{}: {:.2f}".format(label, score),