Skip to content

Instantly share code, notes, and snippets.

@huytung228
Created August 15, 2021 05:12
Show Gist options
  • Save huytung228/1f64e92dc56520ea07a78feb97544e29 to your computer and use it in GitHub Desktop.
Save huytung228/1f64e92dc56520ea07a78feb97544e29 to your computer and use it in GitHub Desktop.
NMS algorithm
import torch
from iou import intersection_over_union
def nms(
    predictions,
    iou_threshold,
    threshold,
    box_format='corners'
):
    '''
    Apply greedy per-class non-max suppression to a list of predictions.

    Parameters:
    - predictions: list of predictions. Each entry is indexed as
      [class, probability, coord, ...]: element 0 is the class label,
      element 1 is the confidence score, and elements 2 onward are the box
      coordinates, converted with ``torch.tensor`` and handed to
      ``intersection_over_union``.
    - iou_threshold: a lower-scoring box of the same class is removed when
      its IoU with an already-kept box is >= this value.
    - threshold: probability threshold; boxes whose probability is not
      strictly greater than this value are removed before suppression.
    - box_format: forwarded unchanged to ``intersection_over_union``
      (its accepted values are defined by that helper).

    Returns:
    - list of the kept prediction entries (the original objects, not
      copies), ordered by descending probability; equal probabilities keep
      their input order.
    '''
    # Drop low-confidence boxes, then visit the rest from most to least
    # confident so every kept box is the strongest one still remaining.
    bounding_boxes = [box for box in predictions if box[1] > threshold]
    bounding_boxes = sorted(bounding_boxes, key=lambda x: x[1], reverse=True)
    bounding_boxes_after_nms = []
    while bounding_boxes:
        chosen_box = bounding_boxes.pop(0)
        bounding_boxes_after_nms.append(chosen_box)
        if not bounding_boxes:
            break
        # The chosen box is invariant while filtering, so convert it once
        # instead of once per comparison.
        chosen_coords = torch.tensor(chosen_box[2:])
        # Boxes of other classes are always kept; same-class boxes survive
        # only if they overlap the chosen box less than iou_threshold.
        bounding_boxes = [
            box for box in bounding_boxes
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            ) < iou_threshold
        ]
    return bounding_boxes_after_nms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment