暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

目标检测中的非极大抑制(NMS)+PyTorch实现

南极Python 2021-05-07
1066

什么是非极大抑制?

在目标检测中,为了提升召回率,通常的检测结果中会出现一个目标对应多个框的结果。

每个框对应的数字代表置信度,通俗来说,就是代表有多大把握确定框内有目标,取值越大,说明越有把握。

但是,对于我们来说,每个目标只需要一个对应框就足够了,因此需要从这么多框中选出最好的,非极大抑制(NMS)就是用来解决这个问题的。

对于只含有一类目标的图片来说,非极大抑制的步骤如下:

(0)设定一个置信度阈值thres和一个IoU阈值iou_thres,删除置信度小于thres的框;

(1)按照置信度对这些框从大到小排序;

(2)取出置信度最大的框放在一边,并将此框与其余的框求IoU,将IoU大于IoU_thres的框删除;

(3) 重复(1)(2),最后剩余的框(也就是取出来的框)就是非极大抑制得到的结果。

注意,以上步骤中,最后得到的结果可能不止一个,因为一张图片中可能有多个同类别目标。

如果是一张图片中含有不止一个类别的目标,只需分别对每个类别的目标重复上面的步骤即可。

下面通过代码实现来加深理解。

非极大抑制的PyTorch实现

关于代码的详细解释已写在注释中:

def nms(bboxes,iou_thres,thres):
    #bboxes:[[1,0.8,x1,y1,x2,y2],...],预测的框的集合,每一个子列表代表一个框
    #其中,
    #bboxes[0][0]:预测类别
    #bboxes[0][1]:置信度
    #bboxes[0][2:]:左上角和右下角坐标
    assert type(bboxes)==list
    
    #去掉置信度小于thres的框
    bboxes=[box for box in bboxes if box[1]>thres]
    #将所有的框按照置信度从大到小排序
    bboxes=sorted(bboxes,key=lambda x:x[1],reverse=True)
    
    bboxes_after_nms=[]#存储最终保留下来的框
    
    while bboxes:
        chosen_box=bboxes.pop(0)#拿出当前置信度最大的框
        bboxes_after_nms.append(chosen_box)#将它保存起来,用于返回
        #更新候选框的集合:
        #某个框和刚刚拿出去的框不是同一类别,则保留;
        #某个框和刚刚拿出去的框是同一类别,但是两者iou小于预先设定的iou_thres,也保留。此时该框可能是同一类别的另外一个目标(图片中有多个同类型目标)
        bboxes=[box for box in bboxes if box[0]!=chosen_box[0] \
                or insert_over_union(torch.tensor(chosen_box[2:]),torch.tensor(box[2:]))<iou_thres]
    return bboxes_after_nms

其中用于计算IoU的函数已经在这篇文章中介绍过,这里直接搬过来了

def insert_over_union(boxes_preds,boxes_labels):
    
    box1_x1=boxes_preds[...,0:1]
    box1_y1=boxes_preds[...,1:2]
    box1_x2=boxes_preds[...,2:3]
    box1_y2=boxes_preds[...,3:4]#shape:[N,1]
    
    box2_x1=boxes_labels[...,0:1]
    box2_y1=boxes_labels[...,1:2]
    box2_x2=boxes_labels[...,2:3]
    box2_y2=boxes_labels[...,3:4]
    
    x1=torch.max(box1_x1,box2_x1)
    y1=torch.max(box1_y1,box2_y1)
    x2=torch.min(box1_x2,box2_x2)
    y2=torch.min(box1_y2,box2_y2)
    
    
    #计算交集区域面积
    intersection=(x2-x1).clamp(0)*(y2-y1).clamp(0)
    
    box1_area=abs((box1_x2-box1_x1)*(box1_y1-box1_y2))
    box2_area=abs((box2_x2-box2_x1)*(box2_y1-box2_y2))
    
    return intersection/(box1_area+box2_area-intersection+1e-6)

测试一下:

参考:

  • [1] https://www.youtube.com/watch?v=YDkjWEN8jNA&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=45
  • [2] https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c


重磅!南极Python交流群已成立,添加下方微信,备注加群即可进群。


                             感谢点赞,分享和在看的你!


文章转载自南极Python,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论