Skip to content

Instantly share code, notes, and snippets.

Created April 23, 2022 17:50
Show Gist options
  • Save Mason-McGough/f96c2704562593c5af2d79cae320a2e2 to your computer and use it in GitHub Desktop.
Hierarchical sampling for NeRF
def sample_hierarchical(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    z_vals: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    perturb: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Apply hierarchical sampling to the rays.

    New depth values are drawn by ``sample_pdf`` from the distribution
    defined by ``weights`` over the midpoints of ``z_vals``; they are then
    merged with the original ``z_vals`` and turned into 3D points.

    Args:
        rays_o: Ray origins; indexed as ``rays_o[..., None, :]``, so the last
            axis holds the coordinates (presumably ``[n_rays, 3]``).
        rays_d: Ray directions, same layout as ``rays_o``.
        z_vals: Depths of the existing samples along each ray, last axis is
            the sample axis.
        weights: Per-sample weights with the same last-axis length as
            ``z_vals`` (the first and last entries are dropped before use).
        n_samples: Number of additional depth samples to draw per ray.
        perturb: Forwarded to ``sample_pdf``; when True the new samples are
            drawn at random positions instead of evenly spaced ones.

    Returns:
        A tuple ``(pts, z_vals_combined, new_z_samples)``:
        the points along each ray at the merged depths, the sorted union of
        the old and new depths, and the newly drawn depths on their own.
    """
    # Draw samples from PDF using z_vals as bins and weights as probabilities.
    # Midpoints give one fewer bin edge than z_vals; trimming the first and
    # last weight leaves one fewer weight than midpoints, which is the
    # bins/weights relationship sample_pdf expects.
    z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    new_z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], n_samples,
                               perturb=perturb)
    # The new depths are treated as constants: no gradient flows back into
    # the weights that produced them.
    new_z_samples = new_z_samples.detach()

    # Resample points from ray based on PDF.
    z_vals_combined, _ = torch.sort(
        torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
    # [N_rays, N_samples + n_samples, 3]
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]
    return pts, z_vals_combined, new_z_samples
def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    perturb: bool = False
) -> torch.Tensor:
    r"""
    Apply inverse transform sampling to a weighted set of points.

    Args:
        bins: Bin edge positions along the last axis. Its last-axis length
            must be ``weights.shape[-1] + 1`` so that it lines up with the
            CDF built below.
        weights: Non-normalized weight of each bin along the last axis.
        n_samples: Number of samples to draw per row.
        perturb: If False, the CDF is probed at ``n_samples`` evenly spaced
            positions in [0, 1] (deterministic). If True, it is probed at
            uniformly random positions.

    Returns:
        Sampled positions, shape ``weights.shape[:-1] + (n_samples,)``.
    """
    # Normalize weights to get PDF. The small constant keeps the
    # normalization finite when every weight in a row is zero.
    pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdim=True)  # [n_rays, weights.shape[-1]]

    # Convert PDF to CDF, with a leading 0 so it has one entry per bin edge.
    cdf = torch.cumsum(pdf, dim=-1)  # [n_rays, weights.shape[-1]]
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # [n_rays, weights.shape[-1] + 1]

    # Take sample positions to grab from CDF. Linear when perturb == 0.
    # u is created in cdf's dtype because searchsorted needs both operands
    # to share a dtype.
    u_shape = list(cdf.shape[:-1]) + [n_samples]
    if not perturb:
        u = torch.linspace(0., 1., n_samples, device=cdf.device, dtype=cdf.dtype)
        u = u.expand(u_shape)  # [n_rays, n_samples]
    else:
        u = torch.rand(u_shape, device=cdf.device, dtype=cdf.dtype)  # [n_rays, n_samples]

    # Find indices along CDF where values in u would be placed.
    u = u.contiguous()  # expand() yields a non-contiguous view
    inds = torch.searchsorted(cdf, u, right=True)  # [n_rays, n_samples]

    # Clamp indices that are out of bounds.
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], dim=-1)  # [n_rays, n_samples, 2]

    # Sample from cdf and the corresponding bin centers: for every u, pick
    # the CDF values and bin positions on either side of it.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1,
                         index=inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1,
                          index=inds_g)

    # Convert samples to ray length by linear interpolation inside the bin.
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # A (near-)zero-width CDF interval would divide by ~0; use 1 instead so
    # the sample falls back to the lower bin edge.
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples  # [n_rays, n_samples]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment