update infer_det.py
This commit is contained in:
@@ -9,7 +9,6 @@ from tqdm import tqdm
|
||||
from models.det_model.preprocess import Compose
|
||||
import cv2
|
||||
|
||||
# 注意:文件名要标准,最好都用下划线
|
||||
|
||||
# Global dictionary
|
||||
SUPPORT_MODELS = {
|
||||
@@ -85,9 +84,7 @@ class PredictConfig(object):
|
||||
self.nms = yml_conf.get("NMS", None)
|
||||
self.fpn_stride = yml_conf.get("fpn_stride", None)
|
||||
|
||||
# 预定义颜色池
|
||||
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
|
||||
# 根据label_list动态生成颜色映射
|
||||
self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
|
||||
|
||||
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
|
||||
@@ -120,9 +117,7 @@ def draw_bbox(image, outputs, infer_config):
|
||||
for output in outputs:
|
||||
cls_id, score, xmin, ymin, xmax, ymax = output
|
||||
if score > infer_config.draw_threshold:
|
||||
# 获取类别名
|
||||
label = infer_config.label_list[int(cls_id)]
|
||||
# 根据类别名获取颜色
|
||||
color = infer_config.colors[label]
|
||||
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
|
||||
cv2.putText(image, "{}: {:.2f}".format(label, score),
|
||||
|
||||
Reference in New Issue
Block a user