Skip to content

Instantly share code, notes, and snippets.

@akbargumbira
Created July 17, 2018 10:22
Show Gist options
  • Save akbargumbira/d51fc9f6db7497d90db5b8d48aae194b to your computer and use it in GitHub Desktop.
Save akbargumbira/d51fc9f6db7497d90db5b8d48aae194b to your computer and use it in GitHub Desktop.
Faster NMS
# coding=utf-8
import copy
import time
import numpy as np
import cv2
from lib.yolo.detector import YOLO
from lib.yolo.utilities import root_path, BoundBox, bbox_iou, decode_netout, correct_yolo_boxes
def old_do_nms(boxes, nms_thresh):
    """Baseline greedy per-class non-maximum suppression, applied in place.

    For each class index, boxes are ranked by that class's score (highest
    first). Walking down the ranking, every box whose score for the class is
    still nonzero zeroes that same class score on each lower-ranked box whose
    ``bbox_iou`` with it is >= ``nms_thresh``.

    Args:
        boxes: sequence of box objects exposing a ``classes`` sequence of
            per-class scores; the class count is read from ``boxes[0]``
            (presumably identical for every box — confirm with the decoder).
        nms_thresh: IoU at or above which the lower-ranked box is suppressed.

    Returns:
        None. Scores are modified in place on the objects in ``boxes``.
    """
    if len(boxes) == 0:
        return

    num_classes = len(boxes[0].classes)
    for cls in range(num_classes):
        # Ascending argsort of negated scores == descending order of scores.
        order = np.argsort([-box.classes[cls] for box in boxes])
        for rank, keep_idx in enumerate(order):
            keeper = boxes[keep_idx]
            # A box suppressed by a higher-ranked one cannot suppress others.
            if keeper.classes[cls] == 0:
                continue
            for other_idx in order[rank + 1:]:
                if bbox_iou(keeper, boxes[other_idx]) >= nms_thresh:
                    boxes[other_idx].classes[cls] = 0
def new_do_nms(boxes, nms_thresh):
    """Greedy per-class non-maximum suppression with shortcuts, in place.

    Produces the same scores as the baseline: for each class index, boxes are
    ranked by that class's score (highest first). Each box whose score for the
    class is still nonzero zeroes that class score on every lower-ranked box
    whose ``bbox_iou`` with it is >= ``nms_thresh``.

    Shortcuts that avoid wasted work without changing the outcome:
      * a class whose maximum score over all boxes is 0 is skipped entirely;
      * the per-class score list is built once and reused for max and sort;
      * a lower-ranked box whose score for the class is already 0 is not
        compared, because suppressing it again would be a no-op.

    Args:
        boxes: sequence of box objects exposing a ``classes`` sequence of
            per-class scores; the class count is read from ``boxes[0]``
            (presumably identical for every box — confirm with the decoder).
        nms_thresh: IoU at or above which the lower-ranked box is suppressed.

    Returns:
        None. Scores are modified in place on the objects in ``boxes``.
    """
    if len(boxes) == 0:
        return

    nb_class = len(boxes[0].classes)
    for c in range(nb_class):
        scores = [box.classes[c] for box in boxes]
        # Nothing to suppress for a class no box scores on.
        if np.max(scores) == 0:
            continue
        # Ascending argsort of negated scores == descending order of scores.
        sorted_indices = np.argsort([-s for s in scores])
        for i, index_i in enumerate(sorted_indices):
            keeper = boxes[index_i]
            # A box suppressed by a higher-ranked one cannot suppress others.
            if keeper.classes[c] == 0:
                continue
            for index_j in sorted_indices[i + 1:]:
                other = boxes[index_j]
                # Already suppressed for this class: the IoU test could only
                # set the score to 0 again, so skip the bbox_iou call.
                if other.classes[c] == 0:
                    continue
                if bbox_iou(keeper, other) >= nms_thresh:
                    other.classes[c] = 0
def _time_nms(nms_fn, nms_boxes, nms_thresh=0.1, repeats=1000):
    """Return the seconds spent calling ``nms_fn(nms_boxes, nms_thresh)`` ``repeats`` times."""
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which follows the adjustable wall clock.
    start = time.perf_counter()
    for _ in range(repeats):
        nms_fn(nms_boxes, nms_thresh)
    return time.perf_counter() - start


# Benchmark: decode one image's detections, then print the time for 1000 runs
# of new_do_nms followed by the time for 1000 runs of old_do_nms.
weight_path = root_path('yolo', 'model', 'weights_fashion_v31.pb')
yolo = YOLO(weight_path)
img_path = root_path('yolo', 'test', 'images', 'n03589791_10279.JPEG')
image = cv2.imread(img_path)
if image is None:
    # cv2.imread reports an unreadable/missing file by returning None instead
    # of raising; fail here with a clear message rather than on `.shape`.
    raise FileNotFoundError('could not read image: %s' % img_path)
image_h, image_w, _ = image.shape
output = yolo.predict(image, decode=False)

boxes = []
for i, netout in enumerate(output):
    # Output i takes the 6-element anchor slice counted from the END of the
    # list (i=0 -> [12:18], i=2 -> [0:6]); this assumes exactly 3 outputs,
    # presumably 3 (w, h) anchor pairs per output — TODO confirm.
    yolo_anchors = yolo._anchors[(2 - i) * 6:(3 - i) * 6]
    boxes += decode_netout(
        np.squeeze(netout, axis=0),
        yolo_anchors,
        yolo._obj_threshold,
        yolo._net_h,
        yolo._net_w)

# correct the sizes of the bounding boxes
correct_yolo_boxes(boxes, image_h, image_w, yolo._net_h, yolo._net_w)

# Each implementation gets its own copy because NMS zeroes scores in place.
# Only the first of the repeated calls sees the freshly decoded scores; later
# calls operate on the scores left by the first — the same for both functions.
copy_boxes = copy.deepcopy(boxes)
print(_time_nms(new_do_nms, copy_boxes))
print(_time_nms(old_do_nms, boxes))
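0.5759725570678711  # sample output: new_do_nms, 1000 calls (seconds)
1.6929998397827148  # sample output: old_do_nms, 1000 calls (seconds)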
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment