Learning To Count Everything

Exercise 1

Data preprocessing source code

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/utils.py#L300-L303

Normalize = transforms.Compose([transforms.ToTensor(),
    transforms.Normalize(mean=IM_NORM_MEAN, std=IM_NORM_STD)])
Transform = transforms.Compose([resizeImage( MAX_HW)])
TransformTrain = transforms.Compose([resizeImageWithGT(MAX_HW)])
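
A minimal usage sketch for these composed transforms (assuming they can be imported from utils.py; the image path, exemplar boxes, and density map below are placeholders):

from PIL import Image
import numpy as np
from utils import Transform, TransformTrain            # assumed import location

img = Image.open('path/to/image.jpg'); img.load()       # placeholder path
rects = [[10, 20, 50, 60], [15, 25, 55, 65]]             # exemplar boxes as [y1, x1, y2, x2]
density = np.zeros((img.size[1], img.size[0]), dtype=np.float32)  # placeholder ground-truth density

# eval / test / demo path: resize image and boxes only
sample = Transform({'image': img, 'lines_boxes': rects})
image_t, boxes_t = sample['image'], sample['boxes']      # (3, H, W) for an RGB image, (1, M, 5) boxes

# train path: resize image, boxes, and ground-truth density together
sample = TransformTrain({'image': img, 'lines_boxes': rects, 'gt_density': density})
image_t, boxes_t, gt_t = sample['image'], sample['boxes'], sample['gt_density']  # gt_t: (1, 1, H, W)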

Processing in train() (train.py)

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/train.py#L91-L93

sample = {'image':image,'lines_boxes':rects,'gt_density':density}
sample = TransformTrain(sample)
image, boxes,gt_density = sample['image'].cuda(), sample['boxes'].cuda(),sample['gt_density'].cuda()

Processing in test.py

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/test.py#L91-L95

image = Image.open('{}/{}'.format(im_dir, im_id))
image.load()
sample = {'image': image, 'lines_boxes': rects}
sample = Transform(sample)
image, boxes = sample['image'], sample['boxes']

Processing in eval() (train.py)

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/train.py#L147-L151

image = Image.open('{}/{}'.format(im_dir, im_id))
image.load()
sample = {'image':image,'lines_boxes':rects}
sample = Transform(sample)
image, boxes = sample['image'].cuda(), sample['boxes'].cuda()

Processing in demo.py

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/demo.py#L93-L97

image = Image.open(args.input_image)
image.load()
sample = {'image': image, 'lines_boxes': rects1}
sample = Transform(sample)
image, boxes = sample['image'], sample['boxes']

The resizeImage() class: preprocessing for eval, test, and demo

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/utils.py#L216-L251

class resizeImage(object):
    """
    If either the width or height of an image exceed a specified value, resize the image so that:
        1. The maximum of the new height and new width does not exceed a specified value
        2. The new height and new width are divisible by 8
        3. The aspect ratio is preserved
    No resizing is done if both height and width are smaller than the specified value
    By: Minh Hoai Nguyen (minhhoai@gmail.com)
    """
    
    def __init__(self, MAX_HW=1504):
        self.max_hw = MAX_HW


    def __call__(self, sample):
        image,lines_boxes = sample['image'], sample['lines_boxes']
        
        W, H = image.size
        if W > self.max_hw or H > self.max_hw:
            scale_factor = float(self.max_hw)/ max(H, W)
            new_H = 8*int(H*scale_factor/8)
            new_W = 8*int(W*scale_factor/8)
            resized_image = transforms.Resize((new_H, new_W))(image)
        else:
            scale_factor = 1
            resized_image = image


        boxes = list()
        for box in lines_boxes:
            box2 = [int(k*scale_factor) for k in box]
            y1, x1, y2, x2 = box2[0], box2[1], box2[2], box2[3]
            boxes.append([0, y1,x1,y2,x2])


        boxes = torch.Tensor(boxes).unsqueeze(0)
        resized_image = Normalize(resized_image)
        sample = {'image':resized_image,'boxes':boxes}
        return sample
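
Two things to notice: the resized height and width are forced to be divisible by 8 (so they line up with the stride-8 and stride-16 feature maps used later), and the exemplar boxes are scaled by the same factor. A standalone sketch of just that arithmetic (the helper name is made up), assuming boxes in [y1, x1, y2, x2] order:

def shrink_to_multiple_of_8(W, H, max_hw=1504):
    """Same sizing rule as resizeImage: cap max(H, W) at max_hw, round down to a multiple of 8."""
    if W <= max_hw and H <= max_hw:
        return W, H, 1.0
    scale = float(max_hw) / max(H, W)
    return 8 * int(W * scale / 8), 8 * int(H * scale / 8), scale

new_W, new_H, s = shrink_to_multiple_of_8(4000, 3000)
print(new_W, new_H)                                   # 1504 1128, both divisible by 8
print([int(k * s) for k in [100, 200, 300, 400]])     # the boxes shrink by the same factor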

The resizeImageWithGT() class: preprocessing for train

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/utils.py#L254-L297

class resizeImageWithGT(object):
    """
    If either the width or height of an image exceed a specified value, resize the image so that:
        1. The maximum of the new height and new width does not exceed a specified value
        2. The new height and new width are divisible by 8
        3. The aspect ratio is preserved
    No resizing is done if both height and width are smaller than the specified value
    By: Minh Hoai Nguyen (minhhoai@gmail.com)
    Modified by: Viresh
    """
    
    def __init__(self, MAX_HW=1504):
        self.max_hw = MAX_HW


    def __call__(self, sample):
        image,lines_boxes,density = sample['image'], sample['lines_boxes'],sample['gt_density']
        
        W, H = image.size
        if W > self.max_hw or H > self.max_hw:
            scale_factor = float(self.max_hw)/ max(H, W)
            new_H = 8*int(H*scale_factor/8)
            new_W = 8*int(W*scale_factor/8)
            resized_image = transforms.Resize((new_H, new_W))(image)
            resized_density = cv2.resize(density, (new_W, new_H))
            orig_count = np.sum(density)
            new_count = np.sum(resized_density)


            if new_count > 0: resized_density = resized_density * (orig_count / new_count)
            
        else:
            scale_factor = 1
            resized_image = image
            resized_density = density
        boxes = list()
        for box in lines_boxes:
            box2 = [int(k*scale_factor) for k in box]
            y1, x1, y2, x2 = box2[0], box2[1], box2[2], box2[3]
            boxes.append([0, y1,x1,y2,x2])


        boxes = torch.Tensor(boxes).unsqueeze(0)
        resized_image = Normalize(resized_image)
        resized_density = torch.from_numpy(resized_density).unsqueeze(0).unsqueeze(0)
        sample = {'image':resized_image,'boxes':boxes,'gt_density':resized_density}
        return sample
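
The only difference from resizeImage is that the ground-truth density map is resized as well, and bilinear resizing changes its sum; multiplying by orig_count / new_count restores the original object count. A small self-contained illustration of that renormalization:

import cv2
import numpy as np

density = np.zeros((1200, 1600), dtype=np.float32)
density[598:603, 798:803] = 1.0                 # a toy density map summing to 25 "objects"
print(density.sum())                            # 25.0

resized = cv2.resize(density, (800, 600))       # bilinear by default; the sum changes
print(resized.sum())                            # generally no longer 25.0

resized = resized * (density.sum() / resized.sum())   # same rescaling as in resizeImageWithGT
print(resized.sum())                            # ~25.0 again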

Source code for the model, train, and eval

Model

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/train.py#L95-L106

with torch.no_grad():
    features = extract_features(resnet50_conv, image.unsqueeze(0), boxes.unsqueeze(0), MAPS, Scales)
features.requires_grad = True
optimizer.zero_grad()
output = regressor(features)


#if image size isn't divisible by 8, gt size is slightly different from output size
if output.shape[2] != gt_density.shape[2] or output.shape[3] != gt_density.shape[3]:
    orig_count = gt_density.sum().detach().item()
    gt_density = F.interpolate(gt_density, size=(output.shape[2],output.shape[3]),mode='bilinear')
    new_count = gt_density.sum().detach().item()
    if new_count > 0: gt_density = gt_density * (orig_count / new_count)
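
For context, the lines immediately after this excerpt in train.py compute the loss on the density maps and take the optimizer step. A hedged continuation sketch, reusing output, gt_density, and optimizer from the excerpt above and assuming a pixel-wise MSE criterion:

import torch
import torch.nn as nn

# Continuation sketch: reuses output, gt_density, optimizer from the excerpt above.
criterion = nn.MSELoss()                  # assumed density-map criterion
loss = criterion(output, gt_density)
loss.backward()
optimizer.step()

pred_cnt = torch.sum(output).item()       # predicted count = sum of the predicted density map
gt_cnt = torch.sum(gt_density).item()
cnt_err = abs(pred_cnt - gt_cnt)          # feeds the epoch-level train MAE / RMSE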

Train

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/train.py#L172-L189

for epoch in range(0,args.epochs):
    regressor.train()
    train_loss,train_mae,train_rmse = train()
    regressor.eval()
    val_mae,val_rmse = eval()
    stats.append((train_loss, train_mae, train_rmse, val_mae, val_rmse))
    stats_file = join(args.output_dir, "stats" +  ".txt")
    with open(stats_file, 'w') as f:
        for s in stats:
            f.write("%s\n" % ','.join([str(x) for x in s]))    
    if best_mae >= val_mae:
        best_mae = val_mae
        best_rmse = val_rmse
        model_name = args.output_dir + '/' + "FamNet.pth"
        torch.save(regressor.state_dict(), model_name)


    print("Epoch {}, Avg. Epoch Loss: {} Train MAE: {} Train RMSE: {} Val MAE: {} Val RMSE: {} Best Val MAE: {} Best Val RMSE: {} ".format(
              epoch+1,  stats[-1][0], stats[-1][1], stats[-1][2], stats[-1][3], stats[-1][4], best_mae, best_rmse))
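
Each line of the resulting stats.txt is a comma-separated record (train_loss, train_mae, train_rmse, val_mae, val_rmse), one per epoch. A minimal sketch for reading it back and finding the best epoch, assuming that layout (the path is a placeholder):

with open('stats.txt') as f:
    stats = [tuple(float(x) for x in line.strip().split(',')) for line in f if line.strip()]

best_epoch = min(range(len(stats)), key=lambda i: stats[i][3])   # index 3 = validation MAE
print('best epoch:', best_epoch + 1, 'val MAE:', stats[best_epoch][3])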

Eval

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/train.py#L126-L167

def eval():
    cnt = 0
    SAE = 0 # sum of absolute errors
    SSE = 0 # sum of square errors


    print("Evaluation on {} data".format(args.test_split))
    im_ids = data_split[args.test_split]
    pbar = tqdm(im_ids)
    for im_id in pbar:
        anno = annotations[im_id]
        bboxes = anno['box_examples_coordinates']
        dots = np.array(anno['points'])


        rects = list()
        for bbox in bboxes:
            x1 = bbox[0][0]
            y1 = bbox[0][1]
            x2 = bbox[2][0]
            y2 = bbox[2][1]
            rects.append([y1, x1, y2, x2])


        image = Image.open('{}/{}'.format(im_dir, im_id))
        image.load()
        sample = {'image':image,'lines_boxes':rects}
        sample = Transform(sample)
        image, boxes = sample['image'].cuda(), sample['boxes'].cuda()


        with torch.no_grad():
            output = regressor(extract_features(resnet50_conv, image.unsqueeze(0), boxes.unsqueeze(0), MAPS, Scales))


        gt_cnt = dots.shape[0]
        pred_cnt = output.sum().item()
        cnt = cnt + 1
        err = abs(gt_cnt - pred_cnt)
        SAE += err
        SSE += err**2


        pbar.set_description('{:<8}: actual-predicted: {:6d}, {:6.1f}, error: {:6.1f}. Current MAE: {:5.2f}, RMSE: {:5.2f}'.format(im_id, gt_cnt, pred_cnt, abs(pred_cnt - gt_cnt), SAE/cnt, (SSE/cnt)**0.5))
        print("")


    print('On {} data, MAE: {:6.2f}, RMSE: {:6.2f}'.format(args.test_split, SAE/cnt, (SSE/cnt)**0.5))
    return SAE/cnt, (SSE/cnt)**0.5
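
SAE and SSE are running sums of the absolute and squared count errors, so the reported metrics are MAE = SAE / cnt and RMSE = sqrt(SSE / cnt). A toy example of the same computation:

errors = [2.0, 5.0, 1.0, 8.0]                              # |gt_cnt - pred_cnt| per image (toy values)
mae = sum(errors) / len(errors)                            # SAE / cnt
rmse = (sum(e ** 2 for e in errors) / len(errors)) ** 0.5  # (SSE / cnt) ** 0.5
print(mae, rmse)                                           # 4.0 and about 4.85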

Exercise 2: Mapping the modules to the source code

[Figure: FamNet architecture diagram whose modules are mapped to the source code below]

Feature Extraction

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/utils.py#L128-L213

def extract_features(feature_model, image, boxes,feat_map_keys=['map3','map4'], exemplar_scales=[0.9, 1.1]):
    N, M = image.shape[0], boxes.shape[2]
    """
    Getting features for the image N * C * H * W
    """
    Image_features = feature_model(image)
    """
    Getting features for the examples (N*M) * C * h * w
    """
    for ix in range(0,N):
        # boxes = boxes.squeeze(0)
        boxes = boxes[ix][0]
        cnter = 0
        Cnter1 = 0
        for keys in feat_map_keys:
            image_features = Image_features[keys][ix].unsqueeze(0)
            if keys == 'map1' or keys == 'map2':
                Scaling = 4.0
            elif keys == 'map3':
                Scaling = 8.0
            elif keys == 'map4':
                Scaling =  16.0
            else:
                Scaling = 32.0
            boxes_scaled = boxes / Scaling
            boxes_scaled[:, 1:3] = torch.floor(boxes_scaled[:, 1:3])
            boxes_scaled[:, 3:5] = torch.ceil(boxes_scaled[:, 3:5])
            boxes_scaled[:, 3:5] = boxes_scaled[:, 3:5] + 1 # make the end indices exclusive 
            feat_h, feat_w = image_features.shape[-2], image_features.shape[-1]
            # make sure exemplars don't go out of bound
            boxes_scaled[:, 1:3] = torch.clamp_min(boxes_scaled[:, 1:3], 0)
            boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], feat_h)
            boxes_scaled[:, 4] = torch.clamp_max(boxes_scaled[:, 4], feat_w)            
            box_hs = boxes_scaled[:, 3] - boxes_scaled[:, 1]
            box_ws = boxes_scaled[:, 4] - boxes_scaled[:, 2]            
            max_h = math.ceil(max(box_hs))
            max_w = math.ceil(max(box_ws))            
            for j in range(0,M):
                y1, x1 = int(boxes_scaled[j,1]), int(boxes_scaled[j,2])  
                y2, x2 = int(boxes_scaled[j,3]), int(boxes_scaled[j,4]) 
                #print(y1,y2,x1,x2,max_h,max_w)
                if j == 0:
                    examples_features = image_features[:,:,y1:y2, x1:x2]
                    if examples_features.shape[2] != max_h or examples_features.shape[3] != max_w:
                        #examples_features = pad_to_size(examples_features, max_h, max_w)
                        examples_features = F.interpolate(examples_features, size=(max_h,max_w),mode='bilinear')                    
                else:
                    feat = image_features[:,:,y1:y2, x1:x2]
                    if feat.shape[2] != max_h or feat.shape[3] != max_w:
                        feat = F.interpolate(feat, size=(max_h,max_w),mode='bilinear')
                        #feat = pad_to_size(feat, max_h, max_w)
                    examples_features = torch.cat((examples_features,feat),dim=0)
            """
            Convolving example features over image features
            """
            h, w = examples_features.shape[2], examples_features.shape[3]
            features = F.conv2d(
                F.pad(image_features, (int(w/2), int((w-1)/2), int(h/2), int((h-1)/2))),
                examples_features
            )
            combined = features.permute([1,0,2,3])
            # computing features for scales 0.9 and 1.1 
            for scale in exemplar_scales:
                    h1 = math.ceil(h * scale)
                    w1 = math.ceil(w * scale)
                    if h1 < 1: # use original size if scaled size is too small
                        h1 = h
                    if w1 < 1:
                        w1 = w
                    examples_features_scaled = F.interpolate(examples_features, size=(h1,w1),mode='bilinear')  
                    features_scaled = F.conv2d(
                        F.pad(image_features, (int(w1/2), int((w1-1)/2), int(h1/2), int((h1-1)/2))),
                        examples_features_scaled)
                    features_scaled = features_scaled.permute([1,0,2,3])
                    combined = torch.cat((combined,features_scaled),dim=1)
            if cnter == 0:
                Combined = 1.0 * combined
            else:
                if Combined.shape[2] != combined.shape[2] or Combined.shape[3] != combined.shape[3]:
                    combined = F.interpolate(combined, size=(Combined.shape[2],Combined.shape[3]),mode='bilinear')
                Combined = torch.cat((Combined,combined),dim=1)
            cnter += 1
        if ix == 0:
            All_feat = 1.0 * Combined.unsqueeze(0)
        else:
            All_feat = torch.cat((All_feat,Combined.unsqueeze(0)),dim=0)
    return All_feat
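
The heart of this function is the correlation step: the exemplar features are used as convolution kernels over the image features, and the asymmetric F.pad keeps the resulting correlation map at the same spatial size as the feature map. A self-contained sketch of just that step on random tensors:

import torch
import torch.nn.functional as F

image_features = torch.rand(1, 512, 64, 96)      # 1 x C x H x W (e.g. a slice of ResNet-50 map3)
examples_features = torch.rand(3, 512, 5, 7)     # M exemplars used as M x C x h x w kernels

h, w = examples_features.shape[2], examples_features.shape[3]
padded = F.pad(image_features, (int(w / 2), int((w - 1) / 2), int(h / 2), int((h - 1) / 2)))
corr = F.conv2d(padded, examples_features)       # 1 x M x H x W correlation maps
print(corr.shape)                                # torch.Size([1, 3, 64, 96])

corr = corr.permute([1, 0, 2, 3])                # M x 1 x H x W, as in extract_features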

Exemplar boxes

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/utils.py#L21-L42

def select_exemplar_rois(image):
    all_rois = []


    print("Press 'q' or Esc to quit. Press 'n' and then use mouse drag to draw a new examplar, 'space' to save.")
    while True:
        key = cv2.waitKey(1) & 0xFF
        if key == 27 or key == ord('q'):
            break
        elif key == ord('n') or key == ord('\r'):  # ord('\r') so the Enter key matches (an int key code never equals the string '\r')
            rect = cv2.selectROI("image", image, False, False)
            x1 = rect[0]
            y1 = rect[1]
            x2 = x1 + rect[2] - 1
            y2 = y1 + rect[3] - 1


            all_rois.append([y1, x1, y2, x2])
            for rect in all_rois:
                y1, x1, y2, x2 = rect
                cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
            print("Press q or Esc to quit. Press 'n' and then use mouse drag to draw a new examplar")


    return all_rois
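
Note the coordinate convention: cv2.selectROI returns (x, y, w, h), and the loop converts it into the [y1, x1, y2, x2] order that resizeImage and extract_features expect. A tiny helper sketch of that conversion (the helper name is made up for illustration):

def roi_to_rect(roi):
    """Convert an OpenCV (x, y, w, h) ROI into the [y1, x1, y2, x2] box used in this repo."""
    x, y, w, h = roi
    return [y, x, y + h - 1, x + w - 1]

print(roi_to_rect((20, 10, 30, 40)))   # [10, 20, 49, 49]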

Feature Correlation + Density Prediction + Adaptation

https://github.com/cvlab-stonybrook/LearningToCountEverything/blob/00ea1888c3a7c495ae06db0bddc3b90b7db8d52f/demo.py#L104-L132

with torch.no_grad():
    features = extract_features(resnet50_conv, image.unsqueeze(0), boxes.unsqueeze(0), MAPS, Scales)


if not args.adapt:
    with torch.no_grad(): output = regressor(features)
else:
    features.requires_grad = True  # 'requires_grad'; the linked excerpt has the typo 'required_grad', which only sets an unused attribute
    #adapted_regressor = copy.deepcopy(regressor)
    adapted_regressor = regressor
    adapted_regressor.train()
    optimizer = optim.Adam(adapted_regressor.parameters(), lr=args.learning_rate)


    pbar = tqdm(range(args.gradient_steps))
    for step in pbar:
        optimizer.zero_grad()
        output = adapted_regressor(features)
        lCount = args.weight_mincount * MincountLoss(output, boxes, use_gpu=use_gpu)
        lPerturbation = args.weight_perturbation * PerturbationLoss(output, boxes, sigma=8, use_gpu=use_gpu)
        Loss = lCount + lPerturbation
        # loss can become zero in some cases, where loss is a 0 valued scalar and not a tensor
        # So Perform gradient descent only for non zero cases
        if torch.is_tensor(Loss):
            Loss.backward()
            optimizer.step()


        pbar.set_description('Adaptation step: {:<3}, loss: {}, predicted-count: {:6.1f}'.format(step, Loss.item(), output.sum().item()))


    features.requires_grad = False  # same typo fix as above
    output = adapted_regressor(features)
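
The torch.is_tensor(Loss) guard exists because, as the comment above notes, MincountLoss and PerturbationLoss can return the plain scalar 0 rather than a tensor, in which case there is nothing to backpropagate. A small self-contained illustration of the failure mode the guard avoids (the zero losses below are hypothetical stand-ins):

import torch

# Hypothetical stand-in: both loss terms happened to return the plain scalar 0.
lCount, lPerturbation = 0, 0
Loss = lCount + lPerturbation          # an int, not a tensor, so Loss.backward() would raise
if torch.is_tensor(Loss):              # same guard as in demo.py
    Loss.backward()
# otherwise the gradient step is simply skipped for this adaptation iteration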