# Last active July 15, 2024 10:31
import time
import math
import torch
import taichi as ti
import taichi.math as tm
import tinytex as ttex
from tinycio import fsio
from .util import *
from .logger import Logging
from .texture import Texture2D, WrapMode
# Radiance-cascades global illumination over a heightfield, implemented with Taichi
# fields/math and torch tensors for the inputs. The values below are class-level
# defaults that instances read through `self`.
class RadianceCascades:
# NOTE(review): mutable class attribute -- this one list object is shared by every
# instance unless an instance rebinds `self.cascades`.
cascades = []
# Lower bound for the cascade-0 ray interval length (raised in _compute_params).
c0_interval_length = 12
# Upper cap on the number of polar (theta) bins of any cascade.
max_theta = 256
# Offset added to the probe's own height before the horizon test in _gather.
probe_h_offset = 0.1
# Strength of the distance falloff in the bilateral probe weights.
depth_bias = 3.5
# Multiplier applied to the input height map in compute().
height_scale_factor = 5.
# Number of padding probes added on each side of a kernel tile's probe grid.
probe_pad = 1
# NOTE(review): n_theta_n / n_phi_n are not referenced anywhere in this copy.
n_theta_n = 128
n_phi_n = 32
def __init__(self, canvas_shape:tuple, kernel_shape:tuple, env_map:torch.Tensor=None, repeat:bool=False):
# Configure the solver.
#   canvas_shape: (H, W) of the full canvas (unpacked as H, W in compute()).
#   kernel_shape: (h, w) of the tiles that compute() processes one at a time.
#   env_map: image used as distant lighting; when None a default EXR is loaded from a
#            relative path (presumably resolved against the current working directory).
#   repeat: True -> textures use WrapMode.REPEAT and rays in _gather may leave the
#           canvas; False -> WrapMode.CLAMP and rays stop at the canvas edge.
self.canvas_shape = canvas_shape
self.kernel_shape = kernel_shape
self.repeat = int(repeat)
self.wrap_mode = WrapMode.REPEAT if repeat else WrapMode.CLAMP
self.log = Logging.get_logger()
# Reference timestamp for _log_timed's per-step durations.
self.time = time.time()
if env_map is None:
env_fp = '../../data/scene/carpentry_shop_02_1k.exr'
env_map = fsio.load_image(env_fp, graphics_format=fsio.GraphicsFormat.SFLOAT32)
self.log.debug('-------- RC INIT --------')
# NOTE(review): nothing visible here assigns n_cascades, c0_probe_size or c0_directions,
# or calls _compute_params(), yet the next line logs 'Computed params' and the lines
# after it read self.cascades / self.n_cascades -- lines appear lost in this copy.
self._log_timed('Computed params')
# Resample the env map to the (n_theta, n_phi) direction grid of the last cascade.
env_shape = (self.cascades[self.n_cascades-1]['n_theta'], self.cascades[self.n_cascades-1]['n_phi'])
self.env_map = Texture2D(ttex.Resampling.resize(env_map, shape=env_shape))
self._log_timed('Set up IBL')
def compute(self,
# NOTE(review): the rest of this signature is missing in this copy. The body uses
# base_color, height, normal, emissive (None -> zeros), bounces, unorm and
# opengl_normals, which are presumably its parameters -- confirm against upstream.
#
# Run the full pipeline: fit every input to the canvas, upload them as Texture2D
# textures, then for each bounce and each kernel tile run _gather -> _merge -> _integrate.
# Inputs are unpacked as (C, H, W) via `.shape[1:]`.
H, W = self.canvas_shape
device = base_color.device
if emissive is None: emissive = torch.zeros(3, H, W).to(device)
self.bounces = bounces
# Original (pre-resize) height-map size, used for dim_scale below.
im_h, im_w = height.shape[1:]
base_color = self._size_to_canvas(base_color).to(device)
height = self._size_to_canvas(height).to(device)
normal = self._size_to_canvas(normal).to(device)
emissive = self._size_to_canvas(emissive).to(device)
if unorm: normal = normal * 2. - 1.
# NOTE(review): this expression's result is discarded, so the Y flip never happens;
# presumably intended `normal[1:2] *= -1`.
if opengl_normals: normal[1:2] * -1
# Scale heights by the canvas-to-image size ratio (taken along the image's longer
# side), then by height_scale_factor.
dim_scale = H / im_h if im_h > im_w else W / im_w
height *= dim_scale * self.height_scale_factor
total_time = time.time()
self.log.debug('------ RC COMPUTE -------')
self._log_timed(f'Computing {H} x {W} on device: {device}')
# height remains denormalized
# Positions: xy from meshgrid_2d remapped with *0.5 + 0.5, z = scaled height.
# NOTE(review): this line is garbled in this copy (the concatenating call, likely
# torch.cat, is missing before the list).
pos =[meshgrid_2d(H, W).permute(0, 3, 1, 2).squeeze(0).to(device) * 0.5 + 0.5, height], dim=0)
self.positions = Texture2D(pos, wrap_mode=self.wrap_mode); self._log_timed('Populated positions')
self.base_color = Texture2D(base_color, wrap_mode=self.wrap_mode); self._log_timed('Populated base color')
# Radiance starts out as the emissive map; _gather samples it along rays.
# NOTE(review): nothing in the visible loop below updates self.radiance between bounces.
self.radiance = Texture2D(emissive, wrap_mode=self.wrap_mode); self._log_timed('Populated radiance')
self.emissive = Texture2D(emissive, wrap_mode=self.wrap_mode); self._log_timed('Populated emissive')
self.normal = Texture2D(normal, wrap_mode=self.wrap_mode); self._log_timed('Populated normals')
# NOTE(review): c0_probe_width reads probe_shape[0] (the height entry), probably meant
# [1]; none of these four values is used in the visible code.
c0_probe_height = self.cascades[0]['probe_shape'][0]
c0_probe_width = self.cascades[0]['probe_shape'][0]
c0_probe_halfheight = c0_probe_height // 2
c0_probe_halfwidth = c0_probe_width // 2
# One gather/merge/integrate pass per bounce and per tile; cx walks tiles along the
# width, cy along the height. Integer division means any canvas remainder that is not a
# whole tile is not processed.
for i in range(self.bounces):
for cx in range(self.canvas_shape[1]//self.kernel_shape[1]):
for cy in range(self.canvas_shape[0]//self.kernel_shape[0]):
kernel = tm.ivec2(cx, cy)
self.log.debug(f'[RC] Chunk [{cx},{cy}]')
self._gather(quadruple=False, kernel=kernel); self._log_timed('Completed gather')
self._merge(kernel=kernel); self._log_timed('Merged cascades')
# NOTE(review): the next line is garbled in this copy -- the final total-time log call
# is fused onto the integrate call.
self._integrate(kernel=kernel); self._log_timed('Integrated cascades')'[RC] Done. TOTAL time elapsed: {time.time() - total_time:.4f}')
def _size_to_canvas(self, im:torch.Tensor) -> torch.Tensor:
    """Fit a (C, H, W) image to ``self.canvas_shape``.

    The image is resized, preserving its aspect ratio, so that it matches the canvas
    along the limiting dimension, then tiled to fill the full canvas.

    Args:
        im: image tensor laid out as (C, H, W).

    Returns:
        A new tensor of the canvas size; the input is not modified.
    """
    im = im.clone()
    im_height, im_width = im.shape[1:]
    canvas_height, canvas_width = self.canvas_shape
    im_aspect = im_width / im_height
    canvas_aspect = canvas_width / canvas_height
    if im_aspect > canvas_aspect:
        # Relatively wider than the canvas: match the canvas width.
        target_width = canvas_width
        target_height = max(1, int(canvas_width / im_aspect))
    else:
        # Relatively taller (or equal): match the canvas height.
        target_height = canvas_height
        target_width = max(1, int(canvas_height * im_aspect))
    im = ttex.Resampling.resize(im, (target_height, target_width))
    im = ttex.Resampling.tile(im, self.canvas_shape)
    return im
def _log_timed(self, msg:str) -> bool:
    """Debug-log *msg* prefixed with the seconds since the previous call, then restart the timer.

    Returns:
        Always ``True`` (lets callers chain it in expressions).
    """
    elapsed = time.time() - self.time
    self.log.debug(f'[RC][{elapsed:.4f}] {msg}')
    self.time = time.time()
    return True
def _compute_params(self):
    """Derive the per-cascade parameter dicts and store them in ``self.cascades``.

    Cascade ``n`` (0-based):
      * covers ray distances [t_min, t_max) with t_max = t1 * 4**n and
        t_min = t1 * 4**(n-1) (0 for n == 0), where t1 = c0_interval_length;
      * has a probe grid 2**n times coarser than cascade 0, plus ``probe_pad``
        probes on each side;
      * uses c0_directions * 2**n azimuth (phi) bins and
        min(n_phi // 2, max_theta) polar (theta) bins;
      * owns a Taichi radiance field of shape (n_probes, n_theta, n_phi) and an
        occlusion bit-vector type with enough u32 words for n_theta bits.

    Side effects: rebinds ``self.cascades`` to a fresh list and may raise
    ``self.c0_interval_length``. Reads ``self.n_cascades``, ``self.c0_probe_size``
    and ``self.c0_directions``, which must be set beforehand.
    """
    shape = self.kernel_shape
    # Fresh per-instance list: the class-level `cascades = []` default is a single list
    # shared by every instance, so cascades must not accumulate in it.
    self.cascades = []
    # Ensure the last cascade's t_max reaches at least sqrt(h * w) of the kernel tile.
    self.c0_interval_length = max(
        math.sqrt(shape[0] * shape[1]) / float(1 << 2 * (self.n_cascades - 1)),
        self.c0_interval_length)
    t1 = self.c0_interval_length
    for n in range(self.n_cascades):
        p = {}
        p['grid_shape'] = (
            int(shape[0] / self.c0_probe_size // math.pow(2, n)) + self.probe_pad * 2,
            int(shape[1] / self.c0_probe_size // math.pow(2, n)) + self.probe_pad * 2)
        p['probe_shape'] = (
            int(shape[0] // (p['grid_shape'][0] - self.probe_pad * 2)),
            int(shape[1] // (p['grid_shape'][1] - self.probe_pad * 2)))
        p['n_probes'] = int(p['grid_shape'][0] * p['grid_shape'][1])
        p['t_min'] = 0.0 if n == 0 else t1 * float(1 << 2 * (n - 1))
        p['t_max'] = t1 * float(1 << 2 * n)
        p['n_phi'] = int(self.c0_directions * math.pow(2, n))
        p['n_theta'] = int(min(p['n_phi']//2, self.max_theta))
        # Arc length / chord of one azimuth bin at t_min and t_max; the chord at t_min
        # sets the ray-march step count (at least 1 step).
        p['t_min_arclength'] = ((p['t_min'] * 2) * math.pi / p['n_phi'])
        p['t_max_arclength'] = ((p['t_max'] * 2) * math.pi / p['n_phi'])
        p['t_min_chord'] = max(2. * p['t_min'] * math.sin(p['t_min_arclength'] / max(2 * p['t_min'], 1)), 1.)
        p['t_max_chord'] = max(2. * p['t_max'] * math.sin(p['t_max_arclength'] / max(2 * p['t_max'], 1)), 1.)
        p['n_steps'] = max(int(min(math.floor((p['t_max'] - p['t_min']) / (p['t_min_chord'])), p['t_max'] - p['t_min'])), 1)
        p['step_size'] = float(p['t_max'] - p['t_min']) / float(p['n_steps'])
        p['map_tex_res'] = (int(p['n_phi'] * p['grid_shape'][0]), int(p['n_theta'] * p['grid_shape'][1]))
        p['radiance'] = ti.field(tm.vec4, shape=(p['n_probes'], p['n_theta'], p['n_phi']))
        # occ_vec bitfield is excessive for heightfield occlusion, but leaves the door open to changes in future
        p['occ_vec'] = ti.types.vector(int(p['n_theta'] / 32) + 1, dtype=ti.u32)
        self.cascades.append(p)
def info(self, printout=True):
    """Describe every cascade's scalar parameters, one ``name value`` pair per line.

    Each cascade section starts with a dashed separator and ``C <index>``; the
    'radiance' field and 'occ_vec' type entries are omitted.

    Args:
        printout: print the text and return None when True; otherwise return it.
    """
    hidden = ('radiance', 'occ_vec')
    chunks = []
    for idx, cascade in enumerate(self.cascades):
        chunks.append('--------------------\n')
        chunks.append('C ' + str(idx) + '\n')
        chunks.extend(key + ' ' + str(cascade[key]) + '\n' for key in cascade if key not in hidden)
    text = ''.join(chunks)
    if not printout:
        return text
    print(text)
def _phi_to_k(self, phi:float, n_phi:float) -> int:
# Map an azimuth angle (radians) to its nearest phi-bin index in [0, n_phi).
# Cascades with more than c0_directions bins are shifted by half a bin (pi / n_phi);
# _k_to_phi subtracts the same shift.
if (n_phi > self.c0_directions): phi += tm.pi / n_phi
# we need the 0.5 to 'wrap around' in lieu of a round here
# (positions in [n_phi - 0.5, n_phi) wrap to [0, 0.5) and so round to bin 0)
return int(tm.round(((phi * n_phi) / (tm.pi * 2)) % (n_phi - 0.5)))
def _theta_to_j(self, theta:float, n_theta:float) -> int:
    """Map a polar angle in radians to the nearest theta-bin index.

    Bins are spaced pi / (n_theta - 1) apart, so theta = 0 maps to 0 and
    theta = pi maps to n_theta - 1.
    """
    bin_pos = theta * (n_theta - 1.) / tm.pi
    return int(tm.round(bin_pos))
def _k_to_phi(self, k:float, n_phi:float) -> float:
    """Azimuth angle in [0, 2*pi) of phi bin *k* out of *n_phi*.

    For cascades with more than c0_directions bins, the angle is shifted back by
    half a bin (pi / n_phi), mirroring the forward shift in _phi_to_k.
    """
    angle = k * ((tm.pi * 2) / n_phi)
    if n_phi > self.c0_directions:
        angle -= tm.pi / n_phi
    return angle % (tm.pi * 2)
def _j_to_theta(self, j:float, n_theta:float) -> float:
    """Polar angle (radians) of theta bin *j*; bins 0..n_theta-1 span [0, pi]."""
    bin_width = tm.pi / (n_theta - 1)
    return j * bin_width
def _jk_to_vec(self, j:int, k:int, n_theta:int, n_phi:int) -> tm.vec3:
    """Unit direction for theta bin *j* and phi bin *k* (spherical coords, z = cos(theta))."""
    theta = self._j_to_theta(j, n_theta)
    phi = self._k_to_phi(k, n_phi)
    sin_theta = tm.sin(theta)
    return tm.vec3(sin_theta * tm.cos(phi), sin_theta * tm.sin(phi), tm.cos(theta))
def _log_polar(self, center:tm.vec2, rho:float, phi:float) -> tm.vec2:
    """Point at distance *rho* from *center* along the direction (phi - pi).

    Despite the name, no logarithm is applied: *rho* is used as a linear radius.
    """
    angle = phi - tm.pi
    dx = rho * tm.cos(angle)
    dy = rho * tm.sin(angle)
    return tm.vec2(center.x + dx, center.y + dy)
def _compute_j(self, euclidian_distance:float, height_difference:float, n_theta:int) -> int:
    """Theta-bin index of the angle to a sample *euclidian_distance* away.

    theta = pi/2 + atan2(-height_difference, euclidian_distance), clamped to
    [0, pi], then converted with _theta_to_j.
    """
    elevation = tm.atan2(-height_difference, euclidian_distance)
    theta = tm.clamp(elevation + tm.pi * 0.5, 0., tm.pi)
    return self._theta_to_j(theta, n_theta)
def _downsample_cascade(self, field:ti.template(), p:int, n_theta:int, n_phi:int, j:int, k:int) -> tm.vec4:
    """Average the 2x2 block of direction bins (j..j+1, k..k+1) of probe *p* in *field*.

    Used by _merge to read a finer-direction cascade at the next-lower angular
    resolution. Theta bins span [0, pi] and do not wrap, so theta indices are
    clamped to [0, n_theta - 1]. Phi bins span [0, 2*pi) and are periodic, so
    phi indices wrap modulo n_phi (bin n_phi is bin 0).
    """
    # Clamp theta: wrapping j+1 past the last bin would jump from theta = pi to theta = 0.
    j0 = tm.clamp(j, 0, n_theta - 1)
    j1 = tm.clamp(j + 1, 0, n_theta - 1)
    # Wrap phi: index n_phi is out of bounds and is the same direction as bin 0.
    k0 = k % n_phi
    k1 = (k + 1) % n_phi
    s0 = field[p, j0, k0]
    s1 = field[p, j1, k0]
    s2 = field[p, j0, k1]
    s3 = field[p, j1, k1]
    return (s0 + s1 + s2 + s3) / 4.
def _bilateral_interp_coeffs(self,
kernel:tm.ivec2) -> (tm.ivec4, tm.vec4):
"""Return probe indices and weights for interpolating between the 4 nearest probes.

NOTE(review): the parameters between `self` and `kernel` are missing in this copy;
the body reads probe_height, probe_width, grid_height, grid_width, xy_grid and
probe_elevation, which are presumably those parameters.

Returns (indices, weights). indices are flattened row-major probe ids
(row * grid_width + col) for (x0,y0), (x0,y1), (x1,y0), (x1,y1). weights are the
bilinear weights, each multiplied by exp(-0.01 * depth_bias * distance) between the
query point and that probe's (pixel position, height).
"""
canvas_shape = tm.vec2(float(self.canvas_shape[1]), float(self.canvas_shape[0]))
kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
probe_halfheight = probe_height * 0.5
probe_halfwidth = probe_width * 0.5
# Half a texel in UV space.
hp = 1. / (canvas_shape * 2.)
# Normalize xy_grid to [0, 1] across the padded tile (tile plus probe_pad probes per side).
padded_kernel = tm.vec2(kernel_shape.x + self.probe_pad * probe_width * 2, kernel_shape.y + self.probe_pad * probe_height * 2)
xy_grid = tm.vec2(xy_grid.x + self.probe_pad * probe_width, xy_grid.y + self.probe_pad * probe_height)
uv_grid = xy_grid / padded_kernel
# Continuous probe-grid coordinate; the -0.5 makes integer values fall on probe centres.
pos = tm.vec2(grid_width * uv_grid.x - 0.5, grid_height * uv_grid.y - 0.5)
x0, y0, x1, y1 = int(tm.floor(pos.x)), int(tm.floor(pos.y)), 0, 0
# Clamp to the grid; at the last row/column x1/y1 collapse onto x0/y0.
x0, y0 = tm.clamp(x0, 0, grid_width-1), tm.clamp(y0, 0, grid_height-1)
x1, y1 = tm.min(x0+1, grid_width-1), tm.min(y0+1, grid_height-1)
# Fractional offset of pos from (x0, y0).
dx = (pos.x + 1.) - (float(x0) + 1.)
dy = (pos.y + 1.) - (float(y0) + 1.)
q00 = int((y0 * grid_width) + x0)
q01 = int((y1 * grid_width) + x0)
q10 = int((y0 * grid_width) + x1)
q11 = int((y1 * grid_width) + x1)
indices = tm.ivec4(q00, q01, q10, q11)
depth_bias = self.depth_bias
# Canvas-pixel position of this tile's padded probe-grid origin.
co = tm.vec2(kernel.x * kernel_shape.x - probe_width * self.probe_pad, kernel.y * kernel_shape.y - probe_height * self.probe_pad)
# Canvas UVs of the four neighbouring probe centres (plus half a texel).
uv00 = tm.vec2(co.x + x0 * probe_width + probe_halfwidth, co.y + y0 * probe_height + probe_halfheight) / canvas_shape.xy + hp
uv01 = tm.vec2(co.x + x0 * probe_width + probe_halfwidth, co.y + y1 * probe_height + probe_halfheight) / canvas_shape.xy + hp
uv10 = tm.vec2(co.x + x1 * probe_width + probe_halfwidth, co.y + y0 * probe_height + probe_halfheight) / canvas_shape.xy + hp
uv11 = tm.vec2(co.x + x1 * probe_width + probe_halfwidth, co.y + y1 * probe_height + probe_halfheight) / canvas_shape.xy + hp
# Heights of those probe centres (z of the positions texture).
d00 = self.positions.sample_bilinear(uv00).z
d01 = self.positions.sample_bilinear(uv01).z
d10 = self.positions.sample_bilinear(uv10).z
d11 = self.positions.sample_bilinear(uv11).z
# NOTE(review): uv_pos adds the pixel-space tile origin `co` to the normalized uv_grid,
# while the probe positions below are in pixel space (uv * canvas_shape); this looks
# like a unit mismatch -- confirm intent.
uv_pos = tm.vec3(co + uv_grid, probe_elevation)
weights = tm.vec4(0.);
weights[0] = (1. - dx) * (1. - dy)
weights[1] = (1. - dx) * dy
weights[2] = dx * (1. - dy)
weights[3] = dx * dy
# Down-weight probes that are far from the query point in (x, y, height).
weights[0] *= tm.exp(-(tm.distance(uv_pos, tm.vec3(uv00 * canvas_shape, d00)) * 0.01) * depth_bias)
weights[1] *= tm.exp(-(tm.distance(uv_pos, tm.vec3(uv01 * canvas_shape, d01)) * 0.01) * depth_bias)
weights[2] *= tm.exp(-(tm.distance(uv_pos, tm.vec3(uv10 * canvas_shape, d10)) * 0.01) * depth_bias)
weights[3] *= tm.exp(-(tm.distance(uv_pos, tm.vec3(uv11 * canvas_shape, d11)) * 0.01) * depth_bias)
# NOTE(review): garbled in this copy; presumably normalizes the weights to sum to 1
# (e.g. dividing by their dot with vec4(1)). If every weight underflows to 0 that
# normalization divides by zero.
weights /=, weights)
return indices, weights
def _gather(self, quadruple:bool, kernel:tm.ivec2) -> bool:
# Ray-march every probe of every cascade across the heightfield for one kernel tile.
#   quadruple: if True, cascade i casts 2**i rays per phi bin, each offset in phi and
#              weighted 1/n_taps; otherwise one ray per bin.
#   kernel: (cx, cy) tile index. Probe centres are placed relative to the tile origin
#           kernel * kernel_shape, shifted back by probe_pad probes.
# Each ray steps from t_min towards t_max of its cascade. At each step, the angle to the
# sampled height selects a start bin sj. Every theta bin from sj upward receives the
# sampled radiance (alpha 1), stopping at the first bin already marked in the occlusion
# bitfield. Each bin it writes is then marked, so nearer samples take precedence.
# Returns True.
canvas_shape = tm.vec2(float(self.canvas_shape[1]), float(self.canvas_shape[0]))
kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
twopi = 2 * tm.pi
for i in ti.static(range(self.n_cascades)):
n_probes = self.cascades[i]['n_probes']
t_min = self.cascades[i]['t_min']
t_max = self.cascades[i]['t_max']
n_phi = self.cascades[i]['n_phi']
n_theta = self.cascades[i]['n_theta']
n_steps = self.cascades[i]['n_steps']
step_size = self.cascades[i]['step_size']
grid_height = self.cascades[i]['grid_shape'][0]
grid_width = self.cascades[i]['grid_shape'][1]
probe_height = self.cascades[i]['probe_shape'][0]
probe_width = self.cascades[i]['probe_shape'][1]
probe_halfheight = probe_height * 0.5
probe_halfwidth = probe_width * 0.5
n_taps = int(1 << i if quadruple else 1)
# Half a texel in UV space.
hp = 1. / (canvas_shape * 2.)
for p, k in ti.ndrange(n_probes, n_phi):
for tap in range(n_taps):
# Probe p sits at (row, col) of the padded probe grid; (x, y) is its canvas-pixel centre.
row = int(p / grid_width)
col = int(p % grid_width)
x = kernel.x * kernel_shape.x - probe_width * self.probe_pad + (float(col) * probe_width) + probe_halfwidth
y = kernel.y * kernel_shape.y - probe_height * self.probe_pad + (float(row) * probe_height) + probe_halfheight
xy_grid = tm.vec2(x, y)
z_probe = self.positions.sample_bilinear(xy_grid / canvas_shape + hp).z
occlusion = self.cascades[i]['occ_vec'](0.)
phi = self._k_to_phi(k, n_phi)
# NOTE(review): _k_to_phi already shifts phi by pi/n_phi when n_phi > c0_directions;
# this adds a further pi/n_theta for i > 0 -- confirm the divisor and that the
# double offset is intended.
if i > 0: phi += tm.pi / float(n_theta)
if tap > 0: phi = (phi + (twopi / (n_phi * n_taps) * tap))
phi = phi % twopi
for s in range(n_steps):
rho = tm.max(t_min + step_size * s, 1.)
uv = self._log_polar(xy_grid, rho, phi) / canvas_shape + hp
# Without repeat, stop marching once the ray leaves the canvas.
if (self.repeat == 0) and (uv.x < 0. or uv.y < 0. or uv.x > 1. or uv.y > 1.): break
z_sample = self.positions.sample_bilinear(uv).z
sj = n_theta - self._compute_j(rho, (z_probe + self.probe_h_offset) - z_sample, n_theta)
for j in range(sj, n_theta):
# Bit j of the occlusion bitfield lives in word j // 32.
occ_idx, occ_val = int(j / 32.), 1 << int(j % 32)
if occlusion[occ_idx] & occ_val: break
result = tm.vec4(self.radiance.sample_bilinear(uv), 1.) / n_taps
# NOTE(review): adds to whatever the field already holds; no reset is visible here,
# so values carry over across kernel tiles and bounces unless cleared elsewhere.
result += self.cascades[i]['radiance'][p, j, k]
self.cascades[i]['radiance'][p, j, k] = result
occlusion[occ_idx] |= occ_val
return True
def _merge(self, kernel:tm.ivec2) -> bool:
# Merge cascades from the last one (n_cascades-1) down to 0 for one kernel tile.
# Each bin whose alpha is below 1 has its rgb blended with what lies beyond it:
#   - last cascade: with the env map (that line is garbled in this copy);
#   - other cascades: with cascade i+1, bilaterally interpolated over the 4 nearest
#     i+1 probes and 2x2-downsampled in direction (_downsample_cascade).
# The blend is tm.mix(outer, current, alpha), i.e. outer*(1-alpha) + current*alpha.
# Returns True.
canvas_shape = tm.vec2(float(self.canvas_shape[1]), float(self.canvas_shape[0]))
kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
for i in ti.static(range(self.n_cascades-1, -1, -1)):
n_probes = self.cascades[i]['n_probes']
n_phi = self.cascades[i]['n_phi']
n_theta = self.cascades[i]['n_theta']
grid_height = self.cascades[i]['grid_shape'][0]
grid_width = self.cascades[i]['grid_shape'][1]
probe_height = self.cascades[i]['probe_shape'][0]
probe_width = self.cascades[i]['probe_shape'][1]
probe_halfheight = probe_height * 0.5
probe_halfwidth = probe_width * 0.5
for p, k in ti.ndrange(n_probes, n_phi):
row = int(p / grid_width)
col = int(p % grid_width)
# Probe centre within the padded tile, then (px, py) in canvas pixels.
x = (float(col) * probe_width) + probe_halfwidth
y = (float(row) * probe_height) + probe_halfheight
px = int(kernel.x * kernel_shape.x - probe_width * self.probe_pad + x)
py = int(kernel.y * kernel_shape.y - probe_height * self.probe_pad + y)
# NOTE(review): garbled in this copy -- the texture/field being indexed is missing
# (presumably the positions texture, given `.z`).
probe_elevation =[int(py), int(px)].z
# Probe centre relative to the unpadded tile origin.
xy_grid = tm.vec2(x - probe_width * self.probe_pad, y - probe_height * self.probe_pad)
for j in range(n_theta):
rad_current = self.cascades[i]['radiance'][p, j, k]
if rad_current.a < 1.:
if ti.static(i == self.n_cascades-1):
# NOTE(review): garbled -- the source indexed by [j, k] is missing (presumably the
# env map, which __init__ resizes to this cascade's (n_theta, n_phi)).
rad_current.rgb = tm.mix([j, k].rgb, rad_current.rgb, rad_current.a)
self.cascades[i]['radiance'][p, j, k] = tm.vec4(rad_current)
# NOTE(review): an `else:` for the ti.static check above appears to be missing in this
# copy -- the code below reads cascade i+1 and cannot apply to the last cascade.
# Normalized (phi, theta) position of this bin, used to find the matching i+1 bin.
uv_sphere = tm.vec2(float(k) / float(n_phi), float(j) / float(n_theta))
m_grid_height = self.cascades[i+1]['grid_shape'][0]
m_grid_width = self.cascades[i+1]['grid_shape'][1]
m_probe_height = float(self.cascades[i+1]['probe_shape'][0])
m_probe_width = float(self.cascades[i+1]['probe_shape'][1])
# NOTE(review): the argument list of this call is missing in this copy.
indices, weights = self._bilateral_interp_coeffs(
m_n_theta = self.cascades[i+1]['n_theta']
m_n_phi = self.cascades[i+1]['n_phi']
m_j = int(m_n_theta * uv_sphere.y)
m_k = int(m_n_phi * uv_sphere.x)
q00 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[0], m_n_theta, m_n_phi, m_j, m_k) * weights[0]
q01 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[1], m_n_theta, m_n_phi, m_j, m_k) * weights[1]
q10 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[2], m_n_theta, m_n_phi, m_j, m_k) * weights[2]
q11 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[3], m_n_theta, m_n_phi, m_j, m_k) * weights[3]
rad_merge = q00 + q01 + q10 + q11
rad_current.rgb = tm.mix(rad_merge.rgb, rad_current.rgb, rad_current.a)
self.cascades[i]['radiance'][p, j, k] = rad_current
return True
def _integrate(self, kernel:tm.ivec2) -> bool:
# Shade every pixel of one kernel tile from the merged cascade 0.
# For each pixel, cascade-0 radiance in every (j, k) direction bin is interpolated from
# the 4 nearest probes (bilateral weights). It is weighted by dot(v, normal) (positive
# side only), sin(theta) and 2*pi / (n_phi * n_theta), and summed. The sum times
# base_color / pi is added to the pixel's output.
# Returns True.
# NOTE(review): the bins implied by _j_to_theta / _k_to_phi have
# d_theta * d_phi = (pi / (n_theta - 1)) * (2*pi / n_phi), which differs from the
# weight used below -- confirm the intended normalization.
canvas_shape = tm.vec2(float(self.canvas_shape[1]), float(self.canvas_shape[0]))
kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
grid_height = self.cascades[0]['grid_shape'][0]
grid_width = self.cascades[0]['grid_shape'][1]
probe_height = float(self.cascades[0]['probe_shape'][0])
probe_width = float(self.cascades[0]['probe_shape'][1])
probe_halfwidth = int(probe_width // 2)
probe_halfheight = int(probe_height // 2)
n_phi = self.cascades[0]['n_phi']
n_theta = self.cascades[0]['n_theta']
for y, x in ti.ndrange(self.kernel_shape[0], self.kernel_shape[1]):
# (px, py): canvas pixel of tile-local (x, y).
px = int(kernel.x * kernel_shape.x + x)
py = int(kernel.y * kernel_shape.y + y)
# NOTE(review): the next three lines are garbled in this copy -- the textures being
# indexed are missing (presumably base color, normal and positions).
base_color =[py, px]
normal =[py, px]
probe_elevation =[py, px].z
xy_grid = tm.vec2(x, y)
radiance = tm.vec4(0.)
# NOTE(review): ct is unused in this copy.
ct = 0
indices, weights = tm.ivec4(0), tm.vec4(0.)
# NOTE(review): the argument list of this call is missing in this copy.
indices, weights = self._bilateral_interp_coeffs(
for j, k in ti.ndrange(n_theta, n_phi):
v = self._jk_to_vec(j, k, n_theta, n_phi)
# NOTE(review): x of the bin direction is mirrored here; reason not visible -- confirm.
v.x = -v.x
light = tm.vec4(0.)
q00 = self.cascades[0]['radiance'][indices[0], j, k] * weights[0]
q01 = self.cascades[0]['radiance'][indices[1], j, k] * weights[1]
q10 = self.cascades[0]['radiance'][indices[2], j, k] * weights[2]
q11 = self.cascades[0]['radiance'][indices[3], j, k] * weights[3]
light = q00 + q01 + q10 + q11
# NOTE(review): garbled -- presumably tm.dot(v, normal).
dot_nl =, normal)
theta = self._j_to_theta(j, n_theta)
# NOTE(review): the next line fuses the accumulation with the final write of the pixel
# (output += base_color / pi * radiance.rgb); the output texture name is missing.
radiance += light * tm.sin(theta) * ((2 * tm.pi) / (n_phi * n_theta)) * dot_nl if dot_nl > 0. else 0.[py, px] =[py, px] + (base_color/tm.pi * radiance.rgb )
return True
# Example:
# rc = tr.RadianceCascades(canvas_shape=(1024,1024), kernel_shape=(64,64), env_map=env_tensor)
# rc.compute(base_color=base_color, height=height, normal=normal, emissive=emissive, bounces=1)
# ..., 0, 1)  (the rest of this example line is truncated in this copy)
