Skip to content

Instantly share code, notes, and snippets.

@Mason-McGough
Created April 23, 2022 17:47
Show Gist options
  • Save Mason-McGough/29c7535d7bee24f5005c2e554239e4de to your computer and use it in GitHub Desktop.
Save Mason-McGough/29c7535d7bee24f5005c2e554239e4de to your computer and use it in GitHub Desktop.
Stratified sampling for NeRF
def sample_stratified(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    near: float,
    far: float,
    n_samples: int,
    perturb: Optional[bool] = True,
    inverse_depth: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Sample points along each ray from regularly-spaced depth bins.

    Args:
        rays_o: Ray origins. All leading dimensions are treated as batch
            dimensions; the last dimension holds the coordinates
            (presumably 3 for xyz -- confirm against the caller).
        rays_d: Ray directions, broadcastable against ``rays_o``.
        near: Depth of the first sample bin edge.
        far: Depth of the last sample bin edge.
        n_samples: Number of samples drawn along every ray.
        perturb: If truthy, draw one uniform random sample inside each bin
            instead of returning the regularly-spaced depths. The jitter is
            drawn independently for every ray.
        inverse_depth: If True, space the bins linearly in inverse depth
            (disparity) rather than linearly in depth. ``near`` must then be
            non-zero.

    Returns:
        A tuple ``(pts, z_vals)`` where ``z_vals`` has shape
        ``rays_o.shape[:-1] + (n_samples,)`` and holds the sampled depths,
        and ``pts = rays_o + rays_d * z_vals`` has shape
        ``rays_o.shape[:-1] + (n_samples, C)``.
        When ``perturb`` is falsy, ``z_vals`` is an expanded (broadcast) view,
        so it should not be written to in place.
    """
    # Normalized positions in [0, 1]; identical for every ray.
    t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
    if not inverse_depth:
        # Sample linearly between `near` and `far`
        z_vals = near * (1. - t_vals) + far * t_vals
    else:
        # Sample linearly in inverse depth (disparity)
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)

    batch_shape = list(rays_o.shape[:-1])
    if perturb:
        # Bin edges: midpoints between neighbouring depths, closed off by the
        # first and last depth so every sample stays within [near, far].
        mids = .5 * (z_vals[1:] + z_vals[:-1])
        upper = torch.concat([mids, z_vals[-1:]], dim=-1)
        lower = torch.concat([z_vals[:1], mids], dim=-1)
        # Draw the jitter per ray (not one vector shared by all rays) so that
        # neighbouring rays do not sample exactly the same depths.
        t_rand = torch.rand(batch_shape + [n_samples], device=z_vals.device)
        # Broadcasting the 1-D bin edges against t_rand yields one row per ray.
        z_vals = lower + (upper - lower) * t_rand
    else:
        # Deterministic depths are shared; expand to one row per ray.
        z_vals = z_vals.expand(batch_shape + [n_samples])

    # Apply scale from `rays_d` and offset from `rays_o` to samples
    # pts: (*batch_shape, n_samples, C)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment