Skip to content

Instantly share code, notes, and snippets.

@huytung228
Created August 15, 2021 05:11
Show Gist options
  • Save huytung228/87491b3139eb22058980a77c4906aafc to your computer and use it in GitHub Desktop.
Save huytung228/87491b3139eb22058980a77c4906aafc to your computer and use it in GitHub Desktop.
calculate IOU
import torch
def intersection_over_union(prediction_box, label_box, box_format='corners'):
    """Compute the element-wise IoU between predicted and label boxes.

    Parameters:
    - prediction_box: tensor of shape (..., 4) holding the predicted boxes.
    - label_box: tensor of shape (..., 4) holding the ground-truth boxes;
      must be broadcastable against ``prediction_box``.
    - box_format: ``'corners'`` (alias ``'corner'``) when the last dimension
      is ``(x1, y1, x2, y2)``, or ``'midpoint'`` when it is
      ``(center_x, center_y, width, height)``.

    Returns:
    - tensor of shape (..., 1) with the IoU of each box pair.

    Raises:
    - ValueError: if ``box_format`` is not one of the supported values.
      Unknown values are rejected instead of being silently interpreted as
      midpoint boxes, which would produce wrong IoU values without warning.
    """
    # Slice with i:i + 1 (rather than indexing with i) so every component
    # keeps a trailing dimension of size 1, and so does the result.
    pred_0, pred_1, pred_2, pred_3 = (
        prediction_box[..., i:i + 1] for i in range(4)
    )
    label_0, label_1, label_2, label_3 = (
        label_box[..., i:i + 1] for i in range(4)
    )

    if box_format in ('corners', 'corner'):
        pred_box_x1, pred_box_y1 = pred_0, pred_1
        pred_box_x2, pred_box_y2 = pred_2, pred_3
        label_box_x1, label_box_y1 = label_0, label_1
        label_box_x2, label_box_y2 = label_2, label_3
    elif box_format == 'midpoint':
        # Convert (center_x, center_y, width, height) to corner coordinates.
        pred_box_x1 = pred_0 - pred_2 / 2
        pred_box_y1 = pred_1 - pred_3 / 2
        pred_box_x2 = pred_0 + pred_2 / 2
        pred_box_y2 = pred_1 + pred_3 / 2
        label_box_x1 = label_0 - label_2 / 2
        label_box_y1 = label_1 - label_3 / 2
        label_box_x2 = label_0 + label_2 / 2
        label_box_y2 = label_1 + label_3 / 2
    else:
        raise ValueError(
            "box_format must be 'corners' or 'midpoint', "
            f"got {box_format!r}"
        )

    # Corners of the intersection rectangle.
    intersection_x1 = torch.max(pred_box_x1, label_box_x1)
    intersection_y1 = torch.max(pred_box_y1, label_box_y1)
    intersection_x2 = torch.min(pred_box_x2, label_box_x2)
    intersection_y2 = torch.min(pred_box_y2, label_box_y2)

    # Clamp each side at 0 so that boxes which do not overlap yield an
    # intersection area of 0 instead of a spurious (possibly positive) product.
    intersection = (
        (intersection_x2 - intersection_x1).clamp(min=0)
        * (intersection_y2 - intersection_y1).clamp(min=0)
    )

    pred_box_area = (
        (pred_box_x2 - pred_box_x1) * (pred_box_y2 - pred_box_y1)
    ).abs()
    label_box_area = (
        (label_box_x2 - label_box_x1) * (label_box_y2 - label_box_y1)
    ).abs()
    union = pred_box_area + label_box_area - intersection

    # The small epsilon keeps the division defined when the union is 0.
    return intersection / (union + 1e-6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment