Created
August 15, 2021 05:09
-
-
Save huytung228/644285317fa317516bc14b83727aaa0d to your computer and use it in GitHub Desktop.
Calculate mAP object detection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from collections import Counter | |
from iou import intersection_over_union | |
def mean_average_precision( | |
pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20 | |
): | |
""" | |
Calculates mean average precision | |
Parameters: | |
pred_boxes (list): list of lists containing all bboxes with each bboxes | |
specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2] | |
true_boxes (list): Similar as pred_boxes except all the correct ones | |
iou_threshold (float): threshold where predicted bboxes is correct | |
box_format (str): "midpoint" or "corners" used to specify bboxes | |
num_classes (int): number of classes | |
Returns: | |
float: mAP value across all classes given a specific IoU threshold | |
""" | |
# list storing all AP for respective classes | |
average_precisions = [] | |
# used for numerical stability later on | |
epsilon = 1e-6 | |
for c in range(num_classes): | |
detections = [] | |
ground_truths = [] | |
# Go through all predictions and targets, | |
# and only add the ones that belong to the | |
# current class c | |
for detection in pred_boxes: | |
if detection[1] == c: | |
detections.append(detection) | |
for true_box in true_boxes: | |
if true_box[1] == c: | |
ground_truths.append(true_box) | |
# find the amount of bboxes for each training example | |
# Counter here finds how many ground truth bboxes we get | |
# for each training example, so let's say img 0 has 3, | |
# img 1 has 5 then we will obtain a dictionary with: | |
# amount_bboxes = {0:3, 1:5} | |
amount_bboxes = Counter([gt[0] for gt in ground_truths]) | |
# We then go through each key, val in this dictionary | |
# and convert to the following (w.r.t same example): | |
# ammount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]} | |
for key, val in amount_bboxes.items(): | |
amount_bboxes[key] = torch.zeros(val) | |
# sort by box probabilities which is index 2 | |
detections.sort(key=lambda x: x[2], reverse=True) | |
TP = torch.zeros((len(detections))) | |
FP = torch.zeros((len(detections))) | |
total_true_bboxes = len(ground_truths) | |
# If none exists for this class then we can safely skip | |
if total_true_bboxes == 0: | |
continue | |
for detection_idx, detection in enumerate(detections): | |
# Only take out the ground_truths that have the same | |
# training idx as detection | |
ground_truth_img = [ | |
bbox for bbox in ground_truths if bbox[0] == detection[0] | |
] | |
num_gts = len(ground_truth_img) | |
best_iou = 0 | |
for idx, gt in enumerate(ground_truth_img): | |
iou = intersection_over_union( | |
torch.tensor(detection[3:]), | |
torch.tensor(gt[3:]), | |
box_format=box_format, | |
) | |
if iou > best_iou: | |
best_iou = iou | |
best_gt_idx = idx | |
if best_iou > iou_threshold: | |
# only detect ground truth detection once | |
if amount_bboxes[detection[0]][best_gt_idx] == 0: | |
# true positive and add this bounding box to seen | |
TP[detection_idx] = 1 | |
amount_bboxes[detection[0]][best_gt_idx] = 1 | |
else: | |
FP[detection_idx] = 1 | |
# if IOU is lower then the detection is a false positive | |
else: | |
FP[detection_idx] = 1 | |
TP_cumsum = torch.cumsum(TP, dim=0) | |
FP_cumsum = torch.cumsum(FP, dim=0) | |
recalls = TP_cumsum / (total_true_bboxes + epsilon) | |
precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon) | |
precisions = torch.cat((torch.tensor([1]), precisions)) | |
recalls = torch.cat((torch.tensor([0]), recalls)) | |
# torch.trapz for numerical integration | |
average_precisions.append(torch.trapz(precisions, recalls)) | |
return sum(average_precisions) / len(average_precisions) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment