Created
August 15, 2021 05:12
-
-
Save huytung228/1f64e92dc56520ea07a78feb97544e29 to your computer and use it in GitHub Desktop.
NMS algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from iou import intersection_over_union | |
def nms(
    predictions,
    iou_threshold,
    threshold,
    box_format='corners',
    *,
    iou_fn=None,
):
    """Apply class-aware non-maximum suppression to a list of predictions.

    Predictions whose probability does not exceed ``threshold`` are dropped.
    The survivors are visited in descending probability order (ties keep
    their input order, since ``sorted`` is stable). Each visited box is kept,
    and every remaining box of the *same* class whose IoU with it is
    ``>= iou_threshold`` is suppressed. Boxes of different classes never
    suppress each other.

    Parameters:
        predictions: list of predictions, each a flat sequence
            ``[class, probability, *coords]``. ``coords`` (``box[2:]``) is
            converted to a tensor and passed to the IoU function, so its
            layout must match ``box_format`` (presumably x1, y1, x2, y2 for
            ``'corners'`` -- confirm against ``iou.intersection_over_union``).
        iou_threshold: IoU at or above which a lower-scoring box of the same
            class is suppressed.
        threshold: probability cut-off; boxes with
            ``probability <= threshold`` are discarded before suppression.
        box_format: forwarded unchanged as ``box_format=`` to the IoU function.
        iou_fn: optional IoU callable invoked as
            ``iou_fn(box_a, box_b, box_format=box_format)``, returning a value
            comparable with ``iou_threshold``. When omitted,
            ``intersection_over_union`` (imported from ``iou``) is used.

    Returns:
        The kept predictions (the same list objects that were passed in), in
        descending probability order. ``predictions`` itself is not modified.
    """
    iou = intersection_over_union if iou_fn is None else iou_fn

    remaining = sorted(
        (box for box in predictions if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []
    while remaining:
        chosen_box = remaining.pop(0)
        # The chosen box's coordinates do not change during this pass, so
        # build its tensor once instead of once per comparison.
        chosen_coords = torch.tensor(chosen_box[2:])
        # Keep a box if it belongs to another class, or if it overlaps the
        # chosen box by less than iou_threshold.
        remaining = [
            box for box in remaining
            if box[0] != chosen_box[0]
            or iou(chosen_coords, torch.tensor(box[2:]), box_format=box_format)
            < iou_threshold
        ]
        kept.append(chosen_box)
    return kept
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment