From 3746ddd4279ba1755be774c048c020f7647feb1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Thu, 18 Apr 2024 00:06:05 +0800 Subject: [PATCH] update infer_det.py --- src/infer_det.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/infer_det.py b/src/infer_det.py index cc5da44..2410b81 100644 --- a/src/infer_det.py +++ b/src/infer_det.py @@ -9,7 +9,6 @@ from tqdm import tqdm from models.det_model.preprocess import Compose import cv2 -# 注意:文件名要标准,最好都用下划线 # Global dictionary SUPPORT_MODELS = { @@ -85,9 +84,7 @@ class PredictConfig(object): self.nms = yml_conf.get("NMS", None) self.fpn_stride = yml_conf.get("fpn_stride", None) - # 预定义颜色池 color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] - # 根据label_list动态生成颜色映射 self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)} if self.arch == 'RCNN' and yml_conf.get('export_onnx', False): @@ -120,9 +117,7 @@ def draw_bbox(image, outputs, infer_config): for output in outputs: cls_id, score, xmin, ymin, xmax, ymax = output if score > infer_config.draw_threshold: - # 获取类别名 label = infer_config.label_list[int(cls_id)] - # 根据类别名获取颜色 color = infer_config.colors[label] cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2) cv2.putText(image, "{}: {:.2f}".format(label, score),