Created
January 21, 2022 09:49
-
-
Save Ending2015a/b034ebbedc55fec1d8ec3b7230a95f1e to your computer and use it in GitHub Desktop.
Scatter N-D (PyTorch implementation). This function can be used in Active Neural SLAM to project depth maps onto a top-down height map.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
from torch import nn | |
import torch_scatter # pip install torch-scatter | |
def ravel_index(index, shape):
    """Ravel multi-dimensional indices into flat (row-major) indices.

    Equivalent to ``np.ravel_multi_index(..., order='C')``, except that the
    coordinates are stacked along the last axis instead of being passed as a
    tuple of arrays.

    Args:
        index (array-like or torch.Tensor): integer coordinates with shape
            ``(..., n)``; ``index[..., k]`` is the coordinate along
            ``shape[k]``.
        shape (tuple/list of int or torch.Size): extents ``(d1, ..., dn)`` of
            the indexed grid.

    Returns:
        torch.Tensor: int64 flat indices with shape ``(...,)``, on the same
        device as ``index``.
    """
    # as_tensor (unlike torch.tensor) does not warn on tensor inputs and keeps
    # the data on its original device.
    index = torch.as_tensor(index).to(dtype=torch.int64)
    # (1, dn, dn*d(n-1), ..., dn*...*d1); tuple() also accepts list shapes.
    extents = torch.as_tensor(
        (1,) + tuple(shape)[::-1], dtype=torch.int64, device=index.device
    )
    # Row-major strides (d2*...*dn, ..., dn, 1), created on index's device so
    # the product below also works for CUDA tensors.
    strides = torch.cumprod(extents, dim=0)[:-1].flip(0)
    return (index * strides).sum(dim=-1)  # (...,)
def masked_scatter_nd_max(canvas, indices, values, mask=None, fill_value=-np.inf):
    '''Scatter vector values onto an n-D canvas, keeping the per-cell maximum.

    Each of the ``N`` points carries an ``n``-D integer coordinate and a
    ``v``-dim value vector. Every output cell holds the channel-wise maximum
    of the values scattered into it, or ``fill_value`` if no point landed
    there. Points that are masked out, or whose coordinates fall outside
    ``(d1, ..., dn)``, are discarded.

    For projecting onto an image-type canvas, ``n`` = 2: the canvas has shape
    ``(..., h, w, v)`` and ``indices`` holds ``(h, w)`` coordinates. For
    example, to build a height map from a point cloud, ``values`` is the
    flattened batch of ``N`` xyz points (``v`` = 3).

    Args:
        canvas (array-like or torch.Tensor): Canvas with shape
            ``(..., d1, ..., dn, v)``. Only its shape, dtype and device are
            used: every cell is reset to ``fill_value`` before scattering,
            and the input itself is never modified.
        indices (array-like or torch.Tensor): Integer ``(d1, ..., dn)``
            coordinates of each point, with shape ``(..., N, n)``.
        values (array-like or torch.Tensor): Vector values with shape
            ``(..., N, v)``; must have the same dtype as ``canvas``.
        mask (array-like or torch.Tensor, optional): Boolean mask with shape
            ``(..., N)``, True for points to scatter. Defaults to all True.
        fill_value (float): Value given to cells that receive no point.

    Returns:
        torch.Tensor: updated canvas, shape ``(..., d1, ..., dn, v)``.
        torch.Tensor: boolean mask with the same shape as the canvas, True
            where an entry holds the fill value (i.e. the cell is empty).
            For an infinite ``fill_value`` this is ``isinf(canvas)``, so
            scattered +-inf values are flagged too; for a finite one it is
            ``canvas == fill_value``.
    '''
    # as_tensor (unlike torch.tensor) does not warn on tensor inputs and keeps
    # each input on its device; nothing below modifies the inputs in place.
    canvas = torch.as_tensor(canvas)
    indices = torch.as_tensor(indices).to(dtype=torch.int64)
    values = torch.as_tensor(values)
    if mask is None:
        # Create the default mask next to the values so GPU inputs work.
        mask = torch.ones(values.shape[:-1], dtype=torch.bool, device=values.device)
    mask = torch.as_tensor(mask).to(dtype=torch.bool)

    # get dimensions
    n = indices.shape[-1]
    v = canvas.shape[-1]
    d1_dn = canvas.shape[-n - 1:-1]
    batch_dims = canvas.shape[:-n - 1]

    # Keep only points whose every coordinate lies in [0, di).
    extents = torch.as_tensor(tuple(d1_dn), dtype=torch.int64, device=indices.device)
    in_bounds = ((indices >= 0) & (indices < extents)).all(dim=-1)  # (..., N)
    mask = mask & in_bounds

    # Flat layout: slot 0 is a dummy cell that absorbs discarded points, and
    # real cells occupy slots 1 .. d1*...*dn.
    flat_indices = ravel_index(indices, d1_dn) + 1  # (..., N)
    flat_indices = flat_indices.masked_fill(~mask, 0)
    flat_canvas = torch.full(
        (*batch_dims, 1 + d1_dn.numel(), v),
        fill_value,
        dtype=canvas.dtype,
        device=canvas.device,
    )  # (..., 1 + d1*...*dn, v)
    torch_scatter.scatter_max(values, flat_indices, dim=-2, out=flat_canvas)
    # Drop the dummy slot and restore the n-D layout.
    canvas = flat_canvas[..., 1:, :].reshape(canvas.shape)

    if np.isinf(fill_value):
        empty = torch.isinf(canvas)
    else:
        empty = canvas == fill_value
    return canvas, empty
# Demo: project a dummy point cloud onto a 3x3 top-down height map.
# dummy point cloud (batch, channel, height, width, xyz) = (1, 2, 5, 5, 3)
# unit: meter
values = np.array([[[[[ 0.92871926, -0.39209746,  0.12709531],
                      [ 0.37437783, -0.25560278, -2.1768249 ],
                      [-0.21010604, -0.87326627, -0.60568358],
                      [ 0.11826354,  0.72192535, -1.96805051],
                      [ 0.07642954,  0.02877341, -0.52130058]],
                     [[-1.07883079, -1.09864275, -1.48197995],
                      [-0.52746128,  0.64207189,  0.95996284],
                      [ 0.29431672, -0.79195994, -0.29312353],
                      [-0.58089971,  0.05356699, -0.18195914],
                      [ 0.63448274, -0.64338309, -0.18980063]],
                     [[-0.39415563, -2.61698209, -1.60855244],
                      [-1.85730103,  1.96747892, -1.36135689],
                      [ 0.17008098,  0.69992018, -1.69435467],
                      [-0.42376153,  0.34204736,  0.3173328 ],
                      [ 1.31884528, -1.28284411, -0.06323276]],
                     [[ 1.01415592, -1.56410225,  2.55963775],
                      [-0.1527702 , -1.27259893,  0.97006746],
                      [ 0.46391498, -0.82628582, -1.22322484],
                      [ 0.51598177, -0.90726735, -2.15268906],
                      [ 0.88671569,  0.34563078,  0.54024559]],
                     [[-1.20541569, -0.27154192, -0.05633884],
                      [-0.36523929, -1.17248391,  0.84481116],
                      [-1.03267173, -0.3065308 , -0.35678831],
                      [ 0.92520116, -0.8984506 , -0.58580828],
                      [-0.62473293, -0.74235885, -0.72037534]]],
                    [[[ 0.09297083,  0.98570852,  1.13650902],
                      [ 0.81261274,  0.21577615, -0.80296376],
                      [ 1.39902247, -0.41790638,  0.37105384],
                      [-0.235837  ,  1.14946586, -0.46826193],
                      [ 0.89406117, -0.81903676, -1.40690595]],
                     [[-1.13937087, -0.81807408,  0.0697723 ],
                      [-0.0718852 , -0.52776485, -1.79533604],
                      [ 0.56097385,  0.26405042,  0.07248514],
                      [-0.51417208,  1.28195223, -1.60939298],
                      [-0.1779261 ,  0.14759517, -0.79710853]],
                     [[ 1.07133254, -0.86649908,  1.10818405],
                      [ 0.51709258,  0.16462324, -0.10645144],
                      [-0.94297979,  0.23160525, -1.00794647],
                      [ 0.05334653,  1.0522464 , -0.6964805 ],
                      [ 0.97591096, -0.2690103 , -1.33586831]],
                     [[ 0.02043337, -1.67731703,  0.72714383],
                      [ 0.7053991 , -0.32442375, -0.41602061],
                      [ 1.01215432,  2.43477928, -0.86891597],
                      [ 1.5247537 , -1.86446265,  0.29876436],
                      [-2.26656319,  1.12710737,  2.89601227]],
                     [[ 0.59182888,  0.29882975, -0.16293282],
                      [-1.09208092, -2.08845169,  2.17915906],
                      [ 0.48356899,  0.22009589,  0.28158253],
                      [ 0.16641354,  0.36653133,  0.49896538],
                      [-0.34871221, -0.56461655,  1.49807201]]]]], dtype=np.float32) + 2.
# per-axis coordinates, each (1, 2, 5, 5)
x = values[..., 0]
y = values[..., 1]
z = values[..., 2]
map_res = 1.0  # map resolution (unit: meter per cell)
map_size = 3  # map cells (map_size by map_size map)
# Quantize the point cloud into cell indices (1, 2, 5, 5). np.floor is needed:
# a bare astype(np.int64) truncates toward zero, so coordinates in (-1, 0)
# would be binned into cell 0 instead of being rejected as out of range.
z_bin = np.floor(-z / map_res + (map_size - 1)).astype(np.int64)
x_bin = np.floor(x / map_res + (map_size - 1) / 2).astype(np.int64)
# filter out invalid areas (indices out of the map range)
isvalid = np.stack((
    z_bin >= 0, z_bin < map_size, x_bin >= 0, x_bin < map_size
), axis=0)
isvalid = np.all(isvalid, axis=0)  # (1, 2, 5, 5)
# create empty map (canvas)
canvas = np.zeros((1, 2, 3, 3, 3), dtype=np.float32)  # (1, 2, 3, 3, 3)
# combine coordinates
indices = np.stack((z_bin, x_bin), axis=-1)  # (1, 2, 5, 5, 2)
flat_indices = torch.tensor(indices).view(1, 2, 25, 2)
flat_values = torch.tensor(values).view(1, 2, 25, 3)
# scatter
new_canvas, mask = masked_scatter_nd_max(canvas, flat_indices, flat_values)
print('z_bin:', z_bin)
print('x_bin:', x_bin)
print('isvalid:', isvalid)
print('mask:', mask)
indices[np.logical_not(isvalid)] = [-1, -1]
indices = indices.reshape(1, 2, -1, 2)
points = np.unique(indices[0, 0], axis=0)
print('points:', points)
points = np.unique(indices[0, 1], axis=0)
print('points:', points)
print('y:', y)
print('new_canvas:', new_canvas[..., 1])
'''
Expecting results:
tensor([[[[ -inf, 1.7285, 2.0536],
          [ -inf, 3.9675, -0.6170],
          [ -inf, -inf, -inf]],
         [[ -inf, -inf, 3.1495],
          [ -inf, -inf, 3.2820],
          [ -inf, -inf, -inf]]]])
'''
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
A TensorFlow 2.0 implementation is available here: https://gist.github.com/Ending2015a/215375b470dcdd50de3c9b2252337888