Created
April 26, 2024 09:51
-
-
Save xvdp/8f84d74416bab23504ef0ff8b457064f to your computer and use it in GitHub Desktop.
morton codes in torch / python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Morton codes should really be computed in C++ rather than Python, as they require bit-level looping.
Still, to help understand them, here are 2 different implementations. Both assume 3D points of shape (..., 3).
1. More general: computes a pyramid, using numpy or torch
>> get_morton_codes()
>> get_points_from_morton()
2. Easier to read: uses Morton codes to compute kNN
>> knn()
Sources and further reading on Morton-code-based voxelization:
/home/z/work/gits/RF/nglod/sdf-net/lib/spc3d.py | |
def to_morton(p): | |
https://github.com/nv-tlabs/nglod/sdf-net/app/spc/spc_utils.py | |
def create_dense_octree(level): | |
# Dividing by 2 will yield the morton code of the parent | |
pc = torch.floor(points / 2.0).short() | |
/home/z/work/gits/RF/nglod/sdf-net/lib/spc3d.py | |
def to_morton(p): | |
https://github.com/nv-tlabs/nglod/sol-renderer/include/spc/spc/spc_math.h | |
static __inline__ __host__ __device__ morton_code ToMorton(point_data V) | |
https://github.com/nv-tlabs/nglod/sol-renderer/include/spc/spc/SPC.cu | |
__global__ void MortonToPoint( | |
""" | |
import numpy as np | |
import torch | |
# pylint: disable=no-member | |
##### 1 Morton Pyramid: | |
def get_morton_codes(p, levels=16):
    """Compute Morton (Z-order) codes for a point cloud.

    The points are first normalized per axis onto an integer grid by
    ``quantize``; the integer coordinates are then bit-interleaved by
    ``to_morton``.

    Args:
        p: float points, (N, 3) torch tensor or numpy array.
        levels: number of bits per coordinate to interleave (at most 16).

    Returns:
        torch float32 (N, 3) -> int64 (N)
        numpy float32 (N, 3) -> uint64 (N)
    """
    assert levels <= 16
    grid_points = quantize(p)
    codes = to_morton(grid_points, levels=levels)
    return codes
def get_points_from_morton(mcode, mmin=0, mmax=1.0, levels=16):
    """Decode Morton codes back into float points.

    Codes are de-interleaved into integer coordinates by ``from_morton`` and
    then rescaled into ``[mmin, mmax]`` by ``to_float``.

    Args:
        mcode: Morton codes, torch int64 (N) or numpy uint64 (N).
        mmin, mmax: bounds of the output range.
        levels: number of bits per coordinate to de-interleave (at most 16).

    Returns:
        torch int64 (N) -> float32 (N, 3)
        numpy uint64 (N) -> float32 (N, 3)
    """
    assert levels <= 16
    grid_points = from_morton(mcode, levels=levels)
    points = to_float(grid_points, mmin, mmax)
    return points
def to_morton(x, levels=16):
    """Interleave the bits of integer 3D coordinates into Morton codes.

    For every bit index ``i`` below ``levels``:
    - bit ``i`` of coordinate 2 goes to bit ``3*i``,
    - bit ``i`` of coordinate 1 goes to bit ``3*i + 1``,
    - bit ``i`` of coordinate 0 goes to bit ``3*i + 2``.

    Args:
        x: integer coordinates of shape (..., 3), as a torch tensor or a
            numpy array (or array-like). Only the lowest ``levels`` bits of
            each coordinate contribute.
        levels: number of bits per coordinate to interleave.

    Returns:
        Morton codes of shape ``x.shape[:-1]``:
        torch -> int64 tensor on ``x.device``, numpy -> uint64 array.
    """
    if torch.is_tensor(x):
        x = x.to(dtype=torch.int64)
        # Use the leading shape (not len(x)) so batched (..., 3) input works.
        mcode = torch.zeros(x.shape[:-1], dtype=torch.int64, device=x.device)
    else:
        x = np.asarray(x).astype(np.uint64)
        mcode = np.zeros(x.shape[:-1], dtype=np.uint64)
    for i in range(levels):
        bit = 0x1 << i
        shift = 2 * i  # bit i moved up by 2*i lands at position 3*i
        mcode |= (x[..., 2] & bit) << shift
        mcode |= (x[..., 1] & bit) << (shift + 1)
        mcode |= (x[..., 0] & bit) << (shift + 2)
    return mcode
# Largest quantized coordinate: 16 bits per axis -> [0, 65535].
# NOTE: must be parenthesized; ``1 << 16 -1`` parses as ``1 << 15`` (32768)
# because ``-`` binds tighter than ``<<``.
_QMAX = (1 << 16) - 1


def quantize(p):
    """Normalize points per axis to integers in [0, (1 << 16) - 1].

    Each column is rescaled with its own min/max taken over the first axis.
    A column with zero extent (all values equal) maps to 0 instead of
    dividing by zero.

    Args:
        p: (N, 3) float points, torch tensor or numpy array.

    Returns:
        torch -> int32 tensor, numpy -> uint16 array, same shape as ``p``.
    """
    if torch.is_tensor(p):
        # torch reductions along a dim return (values, indices).
        mmin = p.min(axis=0)[0]
        mmax = p.max(axis=0)[0]
        extent = mmax - mmin
        extent = torch.where(extent == 0, torch.ones_like(extent), extent)
        return ((p - mmin) / extent * _QMAX).to(dtype=torch.int32)
    mmin = p.min(axis=0)
    mmax = p.max(axis=0)
    extent = mmax - mmin
    extent = np.where(extent == 0, np.ones_like(extent), extent)
    return ((p - mmin) / extent * _QMAX).astype(np.uint16)


def to_float(p, mmin, mmax):
    """Inverse of ``quantize``: map quantized integers back to floats.

    Args:
        p: quantized coordinates, torch tensor or numpy array.
        mmin, mmax: bounds of the output range (scalars or per-axis values
            broadcastable against ``p``).

    Returns:
        ``p / _QMAX * (mmax - mmin) + mmin``, computed from float32 ``p``.
    """
    if torch.is_tensor(p):
        unit = p.to(dtype=torch.float32) / _QMAX
    else:
        unit = p.astype(dtype=np.float32) / _QMAX
    return unit * (mmax - mmin) + mmin
def from_morton(mcode, levels=16):
    """De-interleave Morton codes into quantized integer 3D points.

    Inverse of ``to_morton``. For every bit index ``i`` below ``levels``:
    - bit ``3*i + 2`` goes to bit ``i`` of coordinate 0,
    - bit ``3*i + 1`` goes to bit ``i`` of coordinate 1,
    - bit ``3*i`` goes to bit ``i`` of coordinate 2.

    Args:
        mcode: Morton codes of any shape, torch tensor or numpy array
            (or array-like).
        levels: number of bits per coordinate to recover.

    Returns:
        Points of shape ``mcode.shape + (3,)``:
        torch -> int32 tensor on ``mcode.device``, numpy -> uint16 array.
    """
    # The type test must look at ``mcode``; ``p`` does not exist yet here.
    if torch.is_tensor(mcode):
        # int64 so that masks up to bit 47 (levels=16) fit.
        mcode = mcode.to(dtype=torch.int64)
        p = torch.zeros((*mcode.shape, 3), dtype=torch.int32, device=mcode.device)
    else:
        mcode = np.asarray(mcode, dtype=np.uint64)
        p = np.zeros(mcode.shape + (3,), dtype=np.uint16)
    for i in range(levels):
        p[..., 0] |= (mcode & (0x1 << (3 * i + 2))) >> (2 * i + 2)
        p[..., 1] |= (mcode & (0x1 << (3 * i + 1))) >> (2 * i + 1)
        p[..., 2] |= (mcode & (0x1 << (3 * i + 0))) >> (2 * i + 0)
    return p
##### Implementation 2 | |
def _morton(x):
    """Interleave three 10-bit integer coordinates into one Morton code.

    Each column is spread with the "magic bits" technique, so that its bits
    occupy every third position: bit ``b`` moves to bit ``3*b``. Each
    ``(x | x << s) & mask`` step halves the group size and widens the gaps.
    The three spread columns are then merged:
    - column 0 goes to bits 0, 3, 6, ...
    - column 1 goes to bits 1, 4, 7, ...
    - column 2 goes to bits 2, 5, 8, ...

    The final mask 0x09249249 keeps 10 bit positions per column, so the input
    is intended to lie in [0, 1023]. Values outside that range are not
    validated.

    Args:
        x: (N, 3) integer array or tensor (int32 in the original design).

    Returns:
        (N,) Morton codes.
    """
    spread_steps = (
        (16, 0x030000FF),
        (8, 0x0300F00F),
        (4, 0x030C30C3),
        (2, 0x09249249),
    )
    for shift, mask in spread_steps:
        x = (x | (x << shift)) & mask
    first, second, third = x[:, 0], x[:, 1], x[:, 2]
    return first | (second << 1) | (third << 2)
def toint(x, bits=10):
    """Cast ``x`` to a signed integer type wide enough for ``bits``-bit values.

    ``bits <= 16`` selects int32; anything larger selects int64. Torch
    tensors are converted with ``.to`` and numpy arrays with ``.astype``.
    """
    wide = bits > 16
    if not torch.is_tensor(x):
        return x.astype(np.int64 if wide else np.int32)
    return x.to(dtype=torch.int64 if wide else torch.int32)
def tobits(points, bits=10):
    """Normalize points per axis to integers in [0, 2**bits - 1].

    Each column is rescaled with its own min/max taken over the first axis,
    then cast by ``toint``. A column with zero extent (e.g. a planar point
    cloud) maps to 0 instead of producing 0/0 = NaN before the integer cast.

    Args:
        points: float points, torch tensor or numpy array.
        bits: number of bits per quantized coordinate.
    """
    mmin = points.min(axis=0)
    mmax = points.max(axis=0)
    if torch.is_tensor(points):
        # torch reductions along a dim return (values, indices).
        mmin = mmin[0]
        mmax = mmax[0]
        extent = mmax - mmin
        extent = torch.where(extent == 0, torch.ones_like(extent), extent)
    else:
        extent = mmax - mmin
        extent = np.where(extent == 0, np.ones_like(extent), extent)
    return toint((points - mmin) / extent * ((1 << bits) - 1), bits)
def morton(points, bits=10):
    """Compute Morton codes for float points: float32 -> int32 -> morton.

    Points are quantized to ``bits`` bits per axis by ``tobits``, then
    interleaved by ``_morton``. The masks in ``_morton`` spread at most
    10 bits per axis, which matches the default ``bits=10``.
    """
    return _morton(tobits(points, bits))
def knn(points, box_size=1024):
    """Compute Morton-order neighbour and box-distance statistics.

    Points are sorted along the Morton (Z-order) curve and split into
    consecutive chunks of ``box_size`` points (the last chunk may be
    shorter). An axis-aligned bounding box is computed for every chunk.

    Args:
        points: (N, 3) float torch tensor.
        box_size: number of Morton-consecutive points per box.

    Returns:
        top_kth: (N,) squared distance from each Morton-sorted point to the
            3rd closest of its 6 Morton-order neighbours (see
            ``reject_dist2``; neighbours wrap around at the ends).
        points_to_boxes: (num_boxes, N) squared distance from every
            Morton-sorted point to every box.
        points: the input points, unchanged (not sorted).
        sorted_indices: (N,) permutation such that ``points[sorted_indices]``
            is the order used by ``top_kth`` and ``points_to_boxes``.
    """
    codes = morton(points)
    sorted_indices = torch.argsort(codes)
    sorted_points = points[sorted_indices]

    # Morton-ordered points split into boxes.
    chunks = sorted_points.split(box_size)
    # Per-box bounds, each (num_boxes, 3). Reducing chunk by chunk handles a
    # short trailing chunk uniformly, including N < box_size, where stacking
    # "all but the last chunk" would stack an empty list and raise.
    mmin = torch.stack([chunk.min(dim=0)[0] for chunk in chunks])
    mmax = torch.stack([chunk.max(dim=0)[0] for chunk in chunks])

    top_kth = reject_dist2(sorted_points, num=6, k=3)  # (N,)
    points_to_boxes = dist2_to_box(mmin[:, None], mmax[:, None], sorted_points)  # (num_boxes, N)
    return top_kth, points_to_boxes, points, sorted_indices
def reject_dist2(p, num=6, k=3):
    """Squared distance from each point to its k-th closest sequence neighbour.

    The candidate neighbours of ``p[j]`` are the ``ceil(num/2)`` points before
    it and the ``floor(num/2)`` points after it, in the order of ``p``
    (``num`` candidates in total). ``torch.roll`` wraps around, so points near
    either end also see points from the other end.

    Args:
        p: (N, dims) tensor. ``knn`` passes Morton-sorted points, so sequence
            neighbours tend to be spatial neighbours.
        num: number of candidate neighbours per point.
        k: 1-based rank of the distance to return (k=3 -> third smallest).

    Returns:
        (N,) tensor of squared distances.
    """
    # Integer ceil division: the original used math.ceil without importing math.
    before = (num + 1) // 2  # ceil(num / 2)
    after = (num + 2) // 2   # ceil((num + 1) / 2)
    # Roll offsets before..1 and -1..-(after - 1); 0 (the point itself) is skipped.
    offsets = [i for i in range(before, -after, -1) if i != 0]
    neighbors = torch.stack([torch.roll(p, i, dims=0) for i in offsets], dim=0)
    dist2 = ((p - neighbors) ** 2).sum(dim=-1)  # (num, N)
    return dist2.sort(dim=0)[0][k - 1]
def dist2_to_box(minn, maxx, p):
    """Squared Euclidean distance from points to axis-aligned boxes.

    The distance is zero for points inside or on a box. Inputs broadcast: a
    single box (3) against points (N, 3) gives (N); boxes shaped (M, 1, 3)
    against points (N, 3) give (M, N).

    Args:
        minn, maxx: box lower/upper corners, (3) or (M, 1, 3) tensors.
        p: points, (3) or (N, 3) tensor.
    """
    above = (p - maxx).clamp(min=0)  # overshoot past the upper corner
    below = (minn - p).clamp(min=0)  # undershoot below the lower corner
    gap = above + below
    return gap.pow(2).sum(dim=-1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment