@Joshuaalbert
Created November 8, 2021 17:43
Fills in empty voxels
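For reference, the rule implemented by the function below can be written out per property channel. The notation is introduced here for illustration only: $v$ is one channel of `voxels`, $\varepsilon$ is `zero_threshold`, $\ell$ is `length_scale_voxels`, $h = \lfloor \texttt{support}/2 \rfloor$, and $v$ is taken as zero outside the grid, which is what `padding='SAME'` in `tf.nn.conv3d` does.

```latex
K_{ijk} = \frac{\exp\!\left(-\dfrac{i^2 + j^2 + k^2}{2\ell^2}\right)}
               {\sum_{i',j',k'=-h}^{h} \exp\!\left(-\dfrac{i'^2 + j'^2 + k'^2}{2\ell^2}\right)},
\qquad i, j, k \in \{-h, \dots, h\}

\tilde{v}_{\mathbf{x}} =
\begin{cases}
  v_{\mathbf{x}} & \text{if } |v_{\mathbf{x}}| \ge \varepsilon,\\
  \sum_{i,j,k=-h}^{h} K_{ijk}\, v_{\mathbf{x} + (i,j,k)} & \text{otherwise.}
\end{cases}
```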
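```python
import tensorflow as tf


def fill_in_empty_cells(voxels, length_scale_voxels=3, support=9, zero_threshold=1e-5):
    """
    Fill in empty voxels (absolute value less than zero_threshold) with Gaussian-smoothed values.
    Leave the non-empty voxels as they are.
    Args:
        voxels: float tensor [batch, voxels_per_dimension, voxels_per_dimension, voxels_per_dimension, num_properties]
        length_scale_voxels: float, length scale (in voxels) of the Gaussian kernel, i.e. how "near" to interpolate from.
        support: int, width of the kernel in voxels; should be big enough that there are no regions of this size without a value.
        zero_threshold: float, voxels with absolute value below this are treated as empty.
    Returns:
        voxels with empty cells filled in, same shape as the input. Empty neighbours enter the average
        as zeros, and a cell with no non-empty voxel within the support stays zero.
    """
    voxels = tf.convert_to_tensor(voxels)
    num_properties = voxels.shape[-1]
    # Normalised filter positions. Cast to the voxel dtype: an int32 range divided by a
    # Python number becomes float64, which tf.nn.conv3d rejects for float32 voxels.
    x = tf.cast(tf.range(-(support // 2), support // 2 + 1), voxels.dtype) / length_scale_voxels
    X, Y, Z = tf.meshgrid(x, x, x, indexing='ij')
    R2 = X**2 + Y**2 + Z**2
    # Gaussian kernel exp(-R2 / 2), normalised to sum to 1 (in log space).
    log_kernel = -0.5 * R2
    log_kernel -= tf.reduce_logsumexp(log_kernel)
    kernel = tf.math.exp(log_kernel)
    # Flat alternative: kernel = tf.ones((3, 3, 3), dtype=voxels.dtype) / 27.
    # conv3d filters are [depth, height, width, in_channels, out_channels]; a diagonal
    # channel matrix smooths each property independently instead of requiring num_properties == 1.
    kernel = kernel[:, :, :, None, None] * tf.eye(num_properties, dtype=voxels.dtype)
    # [batch, voxels_per_dimension, voxels_per_dimension, voxels_per_dimension, num_properties]
    smoothed_voxels = tf.nn.conv3d(voxels, filters=kernel, strides=[1, 1, 1, 1, 1], padding='SAME')
    # Keep the original value wherever the voxel is non-empty.
    return tf.where(tf.math.abs(voxels) < zero_threshold, smoothed_voxels, voxels)
```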
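Because empty neighbours enter the Gaussian average as zeros, a filled voxel is pulled towards zero when much of its neighbourhood is also empty. A common remedy, swapped in here and not part of the original gist, is normalised convolution: smooth the values with empty voxels zeroed, smooth the non-empty mask with the same kernel, and divide. The sketch below is a single-channel NumPy re-implementation for illustration only. `gaussian_kernel_3d`, `conv3d_same` and `fill_in_empty_cells_np` are hypothetical helpers, the explicit-loop convolution assumes an odd, cubic kernel and small grids, and `normalised=False` mimics the unnormalised TensorFlow behaviour.

```python
import numpy as np


def gaussian_kernel_3d(length_scale_voxels=3.0, support=9):
    """Isotropic Gaussian kernel of shape (support,) * 3, normalised to sum to 1."""
    h = support // 2
    x = np.arange(-h, h + 1) / length_scale_voxels
    X, Y, Z = np.meshgrid(x, x, x, indexing="ij")
    k = np.exp(-0.5 * (X**2 + Y**2 + Z**2))
    return k / k.sum()


def conv3d_same(volume, kernel):
    """Zero-padded 'SAME' cross-correlation of a 3D array with an odd, cubic 3D kernel."""
    h = kernel.shape[0] // 2
    padded = np.pad(volume, h)  # zeros outside the grid, like padding='SAME'
    nx, ny, nz = volume.shape
    out = np.zeros(volume.shape, dtype=float)
    for i, j, k in np.ndindex(*kernel.shape):
        out += kernel[i, j, k] * padded[i:i + nx, j:j + ny, k:k + nz]
    return out


def fill_in_empty_cells_np(volume, length_scale_voxels=3.0, support=9,
                           zero_threshold=1e-5, normalised=True):
    """Hypothetical single-channel NumPy analogue of fill_in_empty_cells."""
    kernel = gaussian_kernel_3d(length_scale_voxels, support)
    empty = np.abs(volume) < zero_threshold
    if normalised:
        # Normalised convolution: average only over non-empty neighbours.
        values = conv3d_same(np.where(empty, 0.0, volume), kernel)
        weights = conv3d_same((~empty).astype(float), kernel)
        # Cells with no non-empty neighbour in range (weight 0) stay 0.
        smoothed = np.divide(values, weights, out=np.zeros_like(values), where=weights > 0)
    else:
        # Unnormalised smoothing: empty neighbours contribute zeros to the average.
        smoothed = conv3d_same(volume, kernel)
    return np.where(empty, smoothed, volume)


# Usage: a constant field of 2.0 with a single empty voxel in the middle.
field = np.full((5, 5, 5), 2.0)
field[2, 2, 2] = 0.0
filled = fill_in_empty_cells_np(field)
filled_plain = fill_in_empty_cells_np(field, normalised=False)
```

Here the normalised variant restores the value 2.0 at the empty voxel (up to rounding), while the unnormalised one returns something below 2.0, because both the empty centre and the kernel taps that fall outside the 5×5×5 grid count as zeros.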