Created June 18, 2019 00:08
Save smeschke/3cc053d6b370aace7465604eb83d8c5e to your computer and use it in GitHub Desktop.
Applies GrabCut, initialized from a mask generated by a deep-learning (DL) model.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Refine a coarse deep-learning (DL) mask with OpenCV's GrabCut.
#
# The DL mask is converted into a GrabCut initialization mask:
#   * pixels outside the dilated DL mask   -> sure background (cv2.GC_BGD)
#   * pixels inside the eroded DL mask     -> sure foreground (cv2.GC_FGD)
#   * everything in between                -> probable background
#     (cv2.GC_PR_BGD), left for GrabCut to decide
# The intermediate masks and the GrabCut result are plotted with Matplotlib.
import numpy as np
import cv2
from matplotlib import pyplot as plt
import matplotlib.image as mpimg

# Input files: the source image and the mask predicted for it by a DL model.
IMAGE_PATH = '/home/stephen/Downloads/bird.jpg'
MASK_PATH = '/home/stephen/Downloads/bird_mask.png'

# Load image and mask
img = cv2.imread(IMAGE_PATH)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None.
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w, _ = img.shape
deep_mask = cv2.imread(MASK_PATH, cv2.IMREAD_GRAYSCALE)
if deep_mask is None:
    raise FileNotFoundError(f"Could not read mask: {MASK_PATH}")
# cv2.resize takes the target size as (width, height).
deep_mask = cv2.resize(deep_mask, (w, h))
mask = np.zeros(img.shape[:2], np.uint8)
white_background = np.full(img.shape[:2], 255, np.uint8)

# Initialize parameters for the GrabCut algorithm
bgdModel = np.zeros((1, 65), np.float64)
fgdModel = np.zeros((1, 65), np.float64)
iters, size = 3, 5      # morphology iterations and base kernel size
grabcut_iters = 3       # number of GrabCut iterations
kernel = np.ones((size, size), np.uint8)
big_kernel = np.ones((2 * size, 2 * size), np.uint8)

# Dilate the mask to make sure the whole object is covered by the mask
dilation = cv2.dilate(deep_mask, big_kernel, iterations=iters)
# Start with a white background and subtract: only pixels the dilated mask
# does not touch at all (dilation == 0) stay at 255. No uint8 wrap-around is
# possible because the minuend is always 255.
sure_background = white_background - dilation
# Erode to find the sure foreground: pixels still at 255 after erosion.
sure_foreground = cv2.erode(deep_mask, kernel, iterations=iters)

# Build the GrabCut label mask:
#   cv2.GC_PR_BGD (2) - unsure pixels (treated as "probable background")
#   cv2.GC_FGD    (1) - sure foreground pixels
#   cv2.GC_BGD    (0) - sure background pixels
mask[:] = cv2.GC_PR_BGD
mask[sure_background == 255] = cv2.GC_BGD
mask[sure_foreground == 255] = cv2.GC_FGD

# Without any foreground seed GrabCut fails with an opaque internal
# assertion; fail early with an actionable message instead.
if not np.any(mask == cv2.GC_FGD):
    raise ValueError(
        "No sure-foreground pixels: the DL mask must mark the object with "
        "value 255 and the object must be large enough to survive erosion."
    )

# Apply GrabCut
out_mask = mask.copy()
out_mask, _, _ = cv2.grabCut(img, out_mask, None, bgdModel, fgdModel,
                             grabcut_iters, cv2.GC_INIT_WITH_MASK)
# Collapse the four GrabCut labels to binary:
# (probable) background -> 0, (probable) foreground -> 1.
is_background = (out_mask == cv2.GC_PR_BGD) | (out_mask == cv2.GC_BGD)
out_mask = np.where(is_background, 0, 1).astype('uint8')
out_img = img * out_mask[:, :, np.newaxis]

# Plot with Matplotlib
f, axarr = plt.subplots(2, 3, sharex=True)

# Overlays: black out the complement of each "sure" region, then blend
# 50/50 with the source so the region stands out.
background = img.copy()
background[sure_background == 0] = (0, 0, 0)
background = cv2.addWeighted(background, .5, img, .5, 1)
foreground = img.copy()
foreground[sure_foreground == 0] = (0, 0, 0)
foreground = cv2.addWeighted(foreground, .5, img, .5, 1)

panels = [
    ((0, 0), img, 'Source Image'),
    ((1, 0), deep_mask, 'Mask from DL'),
    ((0, 1), background, 'Sure Background'),
    ((1, 1), foreground, 'Sure Foreground'),
    ((0, 2), out_mask, 'GrabCut Mask'),
    ((1, 2), out_img, 'GrabCut Image'),
]
for position, picture, title in panels:
    ax = axarr[position]
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis('off')
plt.show()
smeschke (Author) commented on Jun 18, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.