Last active
September 11, 2016 23:21
-
-
Save scott-gray/5a3cd70465dcd2fe1df1 to your computer and use it in GitHub Desktop.
Custom pooling kernels
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import numpy as np | |
import pycuda.driver as drv | |
from pycuda.tools import context_dependent_memoize | |
from pycuda.compiler import SourceModule | |
class GaussianPool(object):
    """Shape and parameter bookkeeping shared by the Gaussian pooling backends.

    The constructor stores every argument on the instance and derives the
    output extents ``P`` and ``Q`` of an unpadded window sweep:
    ``P = ceil((H - R + 1) / stride_h)`` and ``Q = ceil((W - S + 1) / stride_w)``.

    Attributes set here:
        N, C, H, W, R, S: the constructor arguments, unchanged.
        K: alias of ``C``.
        P, Q: derived output extents.
        str_h, str_w: the strides.
        var_y, var_x, mean_y, mean_x: Gaussian parameters, unchanged.
        dimI: input layout ``(C, H, W, N)``.
        dimO: output layout ``(C, P, Q, N)``.
    """

    def __init__(self,
                 N, C, H, W, R, S,
                 stride_h, stride_w,
                 var_y, var_x,
                 mean_y, mean_x):
        """Record the pooling geometry and return ``(P, Q)``.

        NOTE: returning a value from ``__init__`` only works when this method
        is invoked explicitly (e.g. through ``super().__init__(...)`` in a
        subclass, which is how GaussianPoolGPU obtains P and Q).  Python
        raises ``TypeError`` if ``GaussianPool(...)`` is instantiated
        directly, because ``__init__`` must then return ``None``.
        """
        out_h = _ceil_div(H - R + 1, stride_h)
        out_w = _ceil_div(W - S + 1, stride_w)

        self.N, self.C, self.K = N, C, C
        self.H, self.W = H, W
        self.R, self.S = R, S
        self.P, self.Q = out_h, out_w
        self.str_h, self.str_w = stride_h, stride_w
        self.var_y, self.var_x = var_y, var_x
        self.mean_y, self.mean_x = mean_y, mean_x

        self.dimI = (C, H, W, N)
        self.dimO = (C, out_h, out_w, N)

        return out_h, out_w
class GaussianPoolGPU(GaussianPool):
    """PyCUDA front end that launches the Gaussian pooling kernels.

    The constructor precomputes the launch configuration and the flattened
    scalar argument lists for the forward and backward kernels:

    * ``fprop_args`` -- grid ``(Q, P, C)``, block ``(N, 1, 1)`` and scalars in
      the order of the ``spool_fprop_guassian`` signature.
    * ``bprop_args`` -- grid ``(W, H, C)``, block ``(N, 1, 1)`` and scalars in
      the order of the ``spool_bprop_guassian`` signature.
    * ``shared_size`` -- bytes of dynamic shared memory for the lookup table.

    The kernels run their table-building phase under ``tid < 32`` with warp
    shuffles, so they presumably expect ``N`` to be at least 32 -- TODO
    confirm before using smaller batch sizes.
    """

    def __init__(self,
                 N, C, H, W, R, S,
                 stride_h, stride_w,
                 var_y, var_x,
                 mean_y, mean_x):
        super(GaussianPoolGPU, self).__init__(
            N, C, H, W, R, S, stride_h, stride_w,
            var_y, var_x, mean_y, mean_x)
        # Read the derived extents from the instance rather than relying on
        # the base ``__init__`` returning them.
        P, Q = self.P, self.Q

        # (magic, shift) pairs that let the kernels replace integer division
        # with a multiply and a shift.
        magic_S = _magic32(R*S + 32, S)
        # Use the constructor's stride arguments; the previous code read the
        # module-level globals ``str_h``/``str_w``, which only exist when the
        # demo script at the bottom of the file has run.
        magic_str_h = _magic32(H + R, stride_h)
        magic_str_w = _magic32(W + S, stride_w)

        # Scalar order must match the kernel signatures exactly; each magic
        # tuple flattens to two consecutive ints (magic, shift).
        self.fprop_args = [(Q, P, C), (N, 1, 1), _flatten([
            var_y, var_x, mean_y, mean_x,
            Q, N, Q*N, P*Q*N, H, W, W*N, H*W*N, R, S, R*S,
            magic_S, stride_h, stride_w])]

        self.bprop_args = [(W, H, C), (N, 1, 1), _flatten([
            var_y, var_x, mean_y, mean_x,
            P, Q, N, Q*N, P*Q*N, W, W*N, H*W*N, R, S, R*S,
            magic_S, stride_h, stride_w, magic_str_h, magic_str_w])]

        # The kernels read the table four entries at a time, so round the
        # entry count up to a multiple of 4.
        lut_size = R*S
        if lut_size % 4 != 0:
            lut_size += 4 - lut_size % 4
        # Each entry is an int2: 4 bytes * 2 fields.
        self.shared_size = lut_size * 4 * 2

    def fprop(self, I, O, alpha=1.0, beta=0.0):
        """Launch the forward kernel: ``O = alpha * pool(I) + beta * O``.

        ``I`` and ``O`` must expose a ``gpudata`` attribute (device pointer).
        """
        args = self.fprop_args
        params = [args[0], args[1], O.gpudata, I.gpudata, alpha, beta] + args[2]
        kernel = _get_fprop_kernel()
        kernel.prepared_call(*params, shared_size=self.shared_size)

    def bprop(self, I, O, alpha=1.0, beta=0.0):
        """Launch the backward kernel.

        ``I`` is read with the output layout and ``O`` is written with the
        input layout; both must expose a ``gpudata`` attribute.
        """
        args = self.bprop_args
        params = [args[0], args[1], O.gpudata, I.gpudata, alpha, beta] + args[2]
        kernel = _get_bprop_kernel()
        kernel.prepared_call(*params, shared_size=self.shared_size)
class GaussianPoolCPU(GaussianPool):
    """NumPy reference implementation of Gaussian-weighted pooling.

    ``self.filter`` holds the ``R*S`` window weights with shape
    ``(1, R*S, 1)``.  The weights are the square roots of the Gaussian
    samples, scaled so that their squares sum to one.
    """

    def __init__(self,
                 N, C, H, W, R, S,
                 stride_h, stride_w,
                 var_y, var_x,
                 mean_y, mean_x):
        super(GaussianPoolCPU, self).__init__(
            N, C, H, W, R, S, stride_h, stride_w,
            var_y, var_x, mean_y, mean_x)

        weights = np.array(
            [self.guassian(r, s) for r in range(R) for s in range(S)],
            dtype=np.float32)
        weights = np.sqrt(weights) / np.sqrt(np.sum(weights))

        # Shape (1, R*S, 1) so it broadcasts against (C, R*S, N) gathers.
        self.filter = weights.reshape((1, -1, 1))

    def guassian(self, r, s):
        """Return the unnormalised Gaussian sample for window tap ``(r, s)``.

        Offsets are measured from ``(R // 2, S // 2)`` shifted by the means.
        ``var_y``/``var_x`` multiply the squared offsets, i.e. larger values
        give a narrower bump.
        """
        fy = float(r - self.R // 2) - self.mean_y
        fx = float(s - self.S // 2) - self.mean_x
        return np.exp(-(self.var_y*0.5*fy*fy + self.var_x*0.5*fx*fx),
                      dtype=np.float32)

    @staticmethod
    def _apply_beta(O, beta):
        """Scale ``O`` in place by ``beta`` before accumulation.

        ``beta == 0`` overwrites with zeros instead of multiplying, so that
        NaN/Inf in a not-yet-initialised output buffer cannot leak through.
        """
        if beta == 0:
            O[...] = 0
        else:
            O *= beta

    def fpool_slice(self, q, S, W, stride):
        """Return in-range input coordinates for output position ``q``.

        Returns:
            (sliceI, sliceF): parallel lists of input coordinates inside
            ``[0, W)`` and the window tap index that touches each of them.
        """
        start = q * stride
        taps = [s for s in range(S) if 0 <= start + s < W]
        return [start + s for s in taps], taps

    def fprop(self, I, O, alpha=1.0, beta=0.0):
        """Forward pass: ``O = alpha * pool(I) + beta * O`` (in place on O).

        ``I`` is reshaped to ``(C, -1, N)`` and ``O`` is indexed as
        ``O[:, p, q, :]``.
        """
        flat_in = I.reshape((self.C, -1, self.N))
        weights = self.filter
        W = self.W
        S = self.S

        self._apply_beta(O, beta)

        # Column slices do not depend on the row, so compute them once.
        col_slices = [self.fpool_slice(q, S, W, self.str_w)
                      for q in range(self.Q)]

        for p in range(self.P):
            rows, row_taps = self.fpool_slice(p, self.R, self.H, self.str_h)

            for q, (cols, col_taps) in enumerate(col_slices):
                # Flattened (y, x) offsets into the H*W axis of the input.
                idx_in = np.array(
                    [y*W + x for y in rows for x in cols], dtype=np.intp)
                # Matching flattened (r, s) offsets into the filter.
                idx_f = np.array(
                    [r*S + s for r in row_taps for s in col_taps],
                    dtype=np.intp)

                O[:, p, q, :] += np.sum(
                    flat_in[:, idx_in, :] * weights[:, idx_f, :],
                    axis=1) * alpha

    def bpool_slice(self, x, S, Q, stride):
        """Return the output positions whose window covers input position ``x``.

        Returns:
            (sliceI, sliceF): parallel lists of output coordinates inside
            ``[0, Q)`` and the window tap index through which each of them
            reads ``x``.
        """
        first = x - (S - 1)
        sliceI = []
        sliceF = []
        for s in range(S):
            # Only positions that land exactly on a stride correspond to an
            # output element.
            q, rem = divmod(first + s, stride)
            if rem == 0 and 0 <= q < Q:
                sliceI.append(q)
                sliceF.append(S - s - 1)
        return sliceI, sliceF

    def bprop(self, I, O, alpha=1.0, beta=0.0):
        """Backward pass: scatter ``I`` (output layout) back into ``O``.

        ``I`` is reshaped to ``(K, -1, N)`` and ``O`` is indexed as
        ``O[:, y, x, :]``; the update is ``O = alpha * grad + beta * O``.
        """
        flat_in = I.reshape((self.K, -1, self.N))
        weights = self.filter
        Q = self.Q
        S = self.S

        self._apply_beta(O, beta)

        # Column slices do not depend on the row, so compute them once.
        col_slices = [self.bpool_slice(x, S, Q, self.str_w)
                      for x in range(self.W)]

        for y in range(self.H):
            out_rows, row_taps = self.bpool_slice(y, self.R, self.P, self.str_h)

            for x, (out_cols, col_taps) in enumerate(col_slices):
                idx_in = np.array(
                    [p*Q + q for p in out_rows for q in out_cols],
                    dtype=np.intp)
                idx_f = np.array(
                    [r*S + s for r in row_taps for s in col_taps],
                    dtype=np.intp)

                O[:, y, x, :] += np.sum(
                    flat_in[:, idx_in, :] * weights[:, idx_f, :],
                    axis=1) * alpha
def _ceil_div(x, y): | |
return -(-x // y) | |
# Magic numbers and shift amounts for integer division | |
def _magic32(nmax, d): | |
nc = ((nmax + 1) // d) * d - 1 | |
nbits = len(bin(nmax)) - 2 | |
for p in range(0, 2 * nbits + 1): | |
if 2 ** p > nc * (d - 1 - (2 ** p - 1) % d): | |
m = (2 ** p + d - 1 - (2 ** p - 1) % d) // d | |
return (m, p) | |
raise ValueError("Can't find magic number for division") | |
# flatten a nested list of lists or values | |
def _flatten(lst): | |
return sum(([x] if not isinstance(x, (list, tuple)) | |
else _flatten(x) for x in lst), []) | |
@context_dependent_memoize
def _get_fprop_kernel():
    """Compile and return the prepared forward Gaussian pooling kernel.

    The kernel is launched with one block per output element and one thread
    per ``n``.  Threads with ``tid < 32`` build a shared-memory lookup table
    of (input offset, weight) pairs for the ``R*S`` window taps; every thread
    then gathers through that table.

    Changes relative to the earlier version of the CUDA source:

    * The up-to-3 padding table entries read by the 4-way unrolled loop are
      now zeroed.  Previously their ``funcVal`` was uninitialised shared
      memory, and ``0.0f * NaN`` could turn the whole output into NaN.
    * ``O`` is read whenever ``beta != 0`` (it used to be ``beta > 0``, which
      silently dropped the accumulate term for negative ``beta``).

    NOTE(review): the source uses the legacy ``__shfl_xor`` intrinsic; newer
    CUDA toolkits provide ``*_sync`` replacements -- confirm against the
    toolkit in use.
    """
    code = r"""
union LutEntry {
    struct {
        int sliceI;
        float funcVal;
    } data;
    int2 data2;
};

__global__ void spool_fprop_guassian(
    float* O, const float* I, float alpha, float beta,
    float var_y, float var_x, float mean_y, float mean_x,
    int Q, int N, int QN, int PQN, int H, int W, int WN, int HWN,
    int R, int S, int RS, int magic_S, int shift_S,
    int stride_h, int stride_w)
{
    float __shared__ rcpSqrtSum;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int q = blockIdx.x;
    int p = blockIdx.y;
    int k = blockIdx.z;

    // zigzaq q back and forth to improve L2 cache perf
    if (p & 1)
        q = Q - q - 1;

    I += n;
    O += k*PQN + p*QN + q*N + n;

    float O_val = beta != 0.0f ? __ldg(O) : 0.0f;

    if (tid < 32)
    {
        int pr = p * stride_h;
        int qs = q * stride_w;

        int r_half = R >> 1;
        int s_half = S >> 1;

        int chan_offset = k * HWN;

        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;

        float sum = 0.0f;

        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int x = qs + s;
            int y = pr + r;

            LutEntry entry;
            entry.data.sliceI = chan_offset + y*WN + x*N;

            float fy = (float)(r - r_half) - mean_y;
            float fx = (float)(s - s_half) - mean_x;

            float val = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

            entry.data.funcVal = sqrtf(val);

            sum += val;

            lut[rs] = entry.data2;

            rs += 32;
        }

        // Zero the padding entries read by the unrolled loop below so that
        // uninitialized shared memory (possibly NaN/Inf) is never multiplied in.
        int pad = RS + tid;
        if (pad < ((RS + 3) & ~3))
            lut[pad] = make_int2(0, 0);

        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            sum += __shfl_xor(sum, i);

        rcpSqrtSum = 1.0f / sqrtf(sum);
    }

    __syncthreads();

    float rcp_sqrt_sum = rcpSqrtSum;

    int rs = 0;
    float out = 0.0f;
    while (rs < RS)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < RS ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < RS ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < RS ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < RS ? __ldg(I + entry3.data.sliceI) : 0.0f;

        out += val0 * entry0.data.funcVal * rcp_sqrt_sum;
        out += val1 * entry1.data.funcVal * rcp_sqrt_sum;
        out += val2 * entry2.data.funcVal * rcp_sqrt_sum;
        out += val3 * entry3.data.funcVal * rcp_sqrt_sum;

        rs += 4;
    }
    *O = out*alpha + O_val*beta;
}
"""
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("spool_fprop_guassian")
    # 2 pointers (O, I), 6 floats (alpha .. mean_x), 15 integer scalars.
    kernel.prepare("PP" + "f" * 6 + "I" * 15)
    return kernel
@context_dependent_memoize
def _get_bprop_kernel():
    """Compile and return the prepared backward Gaussian pooling kernel.

    The kernel is launched with one block per input position and one thread
    per ``n``.  Threads with ``tid < 32`` first sum the Gaussian over the
    whole window for normalisation, then build a compacted shared-memory
    table holding only the taps that map to a valid output element.

    Changes relative to the earlier version of the CUDA source:

    * The division of ``q_prime`` by ``stride_w`` now multiplies by
      ``magic_stride_w``; it used ``magic_stride_h``, which is only correct
      when both axes share the same magic number.
    * The up-to-3 padding table entries read by the 4-way unrolled loop are
      zeroed, so uninitialised shared memory cannot inject NaN.
    * ``O`` is read whenever ``beta != 0`` (previously ``beta > 0``).

    NOTE(review): the source uses the legacy ``__shfl_xor`` / ``__ballot``
    intrinsics; newer CUDA toolkits provide ``*_sync`` replacements --
    confirm against the toolkit in use.
    """
    code = r"""
union LutEntry {
    struct {
        int sliceI;
        float funcVal;
    } data;
    int2 data2;
};

__global__ void spool_bprop_guassian(
    float* O, const float* I, float alpha, float beta,
    float var_y, float var_x, float mean_y, float mean_x,
    int P, int Q, int N, int QN, int PQN, int W, int WN, int HWN,
    int R, int S, int RS, int magic_S, int shift_S,
    int stride_h, int stride_w,
    int magic_stride_h, int shift_stride_h,
    int magic_stride_w, int shift_stride_w)
{
    int __shared__ lutSize;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int x = blockIdx.x;
    int y = blockIdx.y;
    int c = blockIdx.z;

    // zigzaq q back and forth to improve L2 cache perf
    if (y & 1)
        x = W - x - 1;

    I += n;
    O += c*HWN + y*WN + x*N + n;

    float O_val = beta != 0.0f ? __ldg(O) : 0.0f;

    int lut_size;

    if (tid < 32)
    {
        int r_half = R >> 1;
        int s_half = S >> 1;

        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;

        float sum = 0.0f;

        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            float fy = (float)(r - r_half) - mean_y;
            float fx = (float)(s - s_half) - mean_x;

            sum += expf( -(var_y2*fy*fy + var_x2*fx*fx) );

            rs += 32;
        }

        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            sum += __shfl_xor(sum, i);

        float rcp_sqrt_sum = 1.0f / sqrtf(sum);

        int pr = y - (R - 1);
        int qs = x - (S - 1);

        int chan_offset = c * PQN;

        unsigned dep_thd_mask = 0xffffffff;
        dep_thd_mask >>= 32 - tid;

        lut_size = 0;

        rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int p_prime = pr + r;
            int q_prime = qs + s;

            // Invert kernel coordinates
            r = R - r - 1;
            s = S - s - 1;

            // p = p_prime / stride_h
            // p_mod = p_prime % stride_h
            int p = p_prime * magic_stride_h; p >>= shift_stride_h;
            int p_mod = p_prime - p*stride_h;
            bool p_bounds = p_mod == 0 && p >= 0 && p < P;

            // q = q_prime / stride_w
            // q_mod = q_prime % stride_w
            int q = q_prime * magic_stride_w; q >>= shift_stride_w;
            int q_mod = q_prime - q*stride_w;
            bool q_bounds = q_mod == 0 && q >= 0 && q < Q;

            bool in_bounds = q_bounds && p_bounds;

            // Get a mask of all valid slices in the warp
            unsigned ballot = __ballot(in_bounds);

            // Count the total valid slices
            unsigned warp_slices = __popc(ballot);

            if (in_bounds)
            {
                // Count all the valid slices below this threadid
                unsigned dep_thd_cnt = __popc(dep_thd_mask & ballot);

                LutEntry entry;
                entry.data.sliceI = chan_offset + p*QN + q*N;

                float fy = (float)(r - r_half) - mean_y;
                float fx = (float)(s - s_half) - mean_x;

                entry.data.funcVal = sqrtf(expf( -(var_y2*fy*fy + var_x2*fx*fx) )) * rcp_sqrt_sum;

                lut[lut_size + dep_thd_cnt] = entry.data2;
            }
            lut_size += warp_slices;
            rs += 32;
        }

        // Zero the padding entries read by the unrolled loop below so that
        // uninitialized shared memory (possibly NaN/Inf) is never multiplied in.
        int pad = lut_size + tid;
        if (pad < ((lut_size + 3) & ~3))
            lut[pad] = make_int2(0, 0);

        lutSize = lut_size;
    }

    __syncthreads();

    lut_size = lutSize;

    int rs = 0;
    float out = 0.0f;
    while (rs < lut_size)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < lut_size ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < lut_size ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < lut_size ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < lut_size ? __ldg(I + entry3.data.sliceI) : 0.0f;

        out += val0 * entry0.data.funcVal;
        out += val1 * entry1.data.funcVal;
        out += val2 * entry2.data.funcVal;
        out += val3 * entry3.data.funcVal;

        rs += 4;
    }
    *O = out*alpha + O_val*beta;
}
"""
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("spool_bprop_guassian")
    # 2 pointers (O, I), 6 floats (alpha .. mean_x), 19 integer scalars.
    kernel.prepare("PP" + "f" * 6 + "I" * 19)
    return kernel
from neon.backends.nervanagpu import NervanaGPU | |
# Demo: run the CPU reference and the GPU kernels on an all-ones input and
# print the first channel / first batch element of each result so the two
# implementations can be compared by eye.
ng = NervanaGPU()

N, C = (32, 1)
H, W = (6, 6)
R, S = (3, 3)
str_h, str_w = (3, 3)
var_y, var_x = (1, 1)
mean_y, mean_x = (0, 0)

cpu_pool = GaussianPoolCPU(
    N, C, H, W, R, S,
    str_h, str_w,
    var_y, var_x,
    mean_y, mean_x)

# Swap in the random variants below for a less trivial comparison.
# I = np.random.uniform(-1.0, 1.0, cpu_pool.dimI)
# E = np.random.uniform(-1.0, 1.0, cpu_pool.dimO)
I = np.ones(cpu_pool.dimI)
E = np.ones(cpu_pool.dimO)
O = np.zeros(cpu_pool.dimO)
B = np.zeros(cpu_pool.dimI)

cpu_pool.fprop(I, O)
cpu_pool.bprop(E, B)

# Single-argument print calls behave the same on Python 2 and Python 3.
print(O[0, :, :, 0])
print(B[0, :, :, 0])

gpu_pool = GaussianPoolGPU(
    N, C, H, W, R, S,
    str_h, str_w,
    var_y, var_x,
    mean_y, mean_x)

I = ng.ones(gpu_pool.dimI)
E = ng.ones(gpu_pool.dimO)
O = ng.zeros(gpu_pool.dimO)
B = ng.zeros(gpu_pool.dimI)

gpu_pool.fprop(I, O)
gpu_pool.bprop(E, B)

print(O.get()[0, :, :, 0])
print(B.get()[0, :, :, 0])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// One lookup-table slot: a gather offset into I plus the filter weight for
// that tap, overlaid on an int2 so it moves through shared memory as a
// single 8-byte value.
union LutEntry {
    struct {
        int   sliceI;
        float funcVal;
    } data;
    int2 data2;
};

// Backward pass of Gaussian-weighted pooling (padded variant).
//
// Each block handles one (x, y, c) position of O and each thread one n.
// Threads with tid < 32 build a compacted table of the window taps that map
// onto a valid (p, q) element of I; every thread then gathers through it.
//
// Each magic_*/shift_* pair must satisfy (v * magic) >> shift == v / divisor
// for the values divided here (see the inline "r = rs / S" style comments).
//
// Requirements visible from the code:
//  * the dynamic shared memory behind `lut` must hold at least RS rounded up
//    to a multiple of 4 int2 entries, because the gather loop reads four
//    entries per iteration;
//  * the table-building phase uses warp votes under `tid < 32`, so it
//    presumably expects at least one full warp per block -- TODO confirm.
extern "C"
__global__ void spool_bprop_gaussian(
    float* O,
    const float* I,
    int P,
    int Q,
    int N,
    int QN,
    int PQN,
    int W,
    int WN,
    int HWN,
    int R,
    int S,
    int RS,
    int magic_S,
    int shift_S,
    int stride_h,
    int stride_w,
    int magic_stride_h,
    int shift_stride_h,
    int magic_stride_w,
    int shift_stride_w,
    int pad_h,
    int pad_w,
    float var_y,
    float var_x,
    float mean_y,
    float mean_x
    )
{
    int __shared__ lutSize;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int x = blockIdx.x;
    int y = blockIdx.y;
    int c = blockIdx.z;

    // zigzag x back and forth to improve L2 cache perf
    if (y & 1)
        x = W - x - 1;

    I += n;
    O += c*HWN + y*WN + x*N + n;

    int lut_size;

    if (tid < 32)
    {
        int pr = y - (R - pad_h - 1);
        int qs = x - (S - pad_w - 1);

        int r_half = (R - 1) >> 1;
        int s_half = (S - 1) >> 1;

        int chan_offset = c * PQN;

        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;

        // Bits set for every lane below this one.
        unsigned dep_thd_mask = 0xffffffff;
        dep_thd_mask >>= 32 - tid;

        lut_size = 0;

        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int p_prime = pr + r;
            int q_prime = qs + s;

            // Invert kernel coordinates
            r = R - r - 1;
            s = S - s - 1;

            // p = p_prime / stride_h
            // p_mod = p_prime % stride_h
            int p = p_prime * magic_stride_h; p >>= shift_stride_h;
            int p_mod = p_prime - p*stride_h;
            bool p_bounds = p_mod == 0 && p >= 0 && p < P;

            // q = q_prime / stride_w
            // q_mod = q_prime % stride_w
            // (must use the stride_w magic number; magic_stride_h was used
            // here before, which is only right when the two are equal)
            int q = q_prime * magic_stride_w; q >>= shift_stride_w;
            int q_mod = q_prime - q*stride_w;
            bool q_bounds = q_mod == 0 && q >= 0 && q < Q;

            bool in_bounds = q_bounds && p_bounds;

            // Get a mask of all valid slices in the warp
            unsigned ballot = __ballot(in_bounds);

            // Count the total valid slices
            unsigned warp_slices = __popc(ballot);

            if (in_bounds)
            {
                // Count all the valid slices below this threadid
                unsigned dep_thd_cnt = __popc(dep_thd_mask & ballot);

                LutEntry entry;
                entry.data.sliceI = chan_offset + p*QN + q*N;

                float fy = (float)(r - r_half) - mean_y;
                float fx = (float)(s - s_half) - mean_x;

                entry.data.funcVal = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

                lut[lut_size + dep_thd_cnt] = entry.data2;
            }
            lut_size += warp_slices;
            rs += 32;
        }

        // Zero the padding entries read by the unrolled loop below so that
        // uninitialized shared memory (possibly NaN/Inf) is never multiplied
        // in: 0.0f * NaN would otherwise poison the output.
        int pad = lut_size + tid;
        if (pad < ((lut_size + 3) & ~3))
            lut[pad] = make_int2(0, 0);

        lutSize = lut_size;
    }

    __syncthreads();

    lut_size = lutSize;

    int rs = 0;
    float out = 0.0f;
    while (rs < lut_size)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < lut_size ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < lut_size ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < lut_size ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < lut_size ? __ldg(I + entry3.data.sliceI) : 0.0f;

        out += val0 * entry0.data.funcVal;
        out += val1 * entry1.data.funcVal;
        out += val2 * entry2.data.funcVal;
        out += val3 * entry3.data.funcVal;

        rs += 4;
    }
    *O = out;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// One lookup-table slot: a gather offset into I plus the filter weight for
// that tap, overlaid on an int2 so it moves through shared memory as a
// single 8-byte value.
union LutEntry {
    struct {
        int   sliceI;
        float funcVal;
    } data;
    int2 data2;
};

// Forward pass of Gaussian-weighted pooling (padded variant).
//
// Each block handles one (q, p, k) element of O and each thread one n.
// Threads with tid < 32 build a compacted table containing only the window
// taps that fall inside the H x W input; every thread then gathers through
// that table and writes the weighted sum to O.
//
// magic_S/shift_S must satisfy (rs * magic_S) >> shift_S == rs / S for
// 0 <= rs < RS (see the inline "r = rs / S" comment).
//
// Requirements visible from the code:
//  * the dynamic shared memory behind `lut` must hold at least RS rounded up
//    to a multiple of 4 int2 entries, because the gather loop reads four
//    entries per iteration;
//  * the table-building phase uses warp votes under `tid < 32`, so it
//    presumably expects at least one full warp per block -- TODO confirm.
extern "C"
__global__ void spool_fprop_gaussian(
    float* O,
    const float* I,
    int Q,
    int N,
    int QN,
    int PQN,
    int H,
    int W,
    int WN,
    int HWN,
    int R,
    int S,
    int RS,
    int magic_S,
    int shift_S,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    float var_y,
    float var_x,
    float mean_y,
    float mean_x
    )
{
    int __shared__ lutSize;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int q = blockIdx.x;
    int p = blockIdx.y;
    int k = blockIdx.z;

    // zigzag q back and forth to improve L2 cache perf
    if (p & 1)
        q = Q - q - 1;

    I += n;
    O += k*PQN + p*QN + q*N + n;

    int lut_size;

    if (tid < 32)
    {
        int pr = p * stride_h - pad_h;
        int qs = q * stride_w - pad_w;

        int r_half = (R - 1) >> 1;
        int s_half = (S - 1) >> 1;

        int chan_offset = k * HWN;

        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;

        // Bits set for every lane below this one.
        unsigned dep_thd_mask = 0xffffffff;
        dep_thd_mask >>= 32 - tid;

        lut_size = 0;

        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int x = qs + s;
            int y = pr + r;

            bool in_bounds = x >= 0 && x < W && y >= 0 && y < H;

            // Get a mask of all valid slices in the warp
            unsigned ballot = __ballot(in_bounds);

            // Count the total valid slices
            unsigned warp_slices = __popc(ballot);

            if (in_bounds)
            {
                // Count all the valid slices below this threadid
                unsigned dep_thd_cnt = __popc(dep_thd_mask & ballot);

                LutEntry entry;
                entry.data.sliceI = chan_offset + y*WN + x*N;

                float fy = (float)(r - r_half) - mean_y;
                float fx = (float)(s - s_half) - mean_x;

                entry.data.funcVal = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

                lut[lut_size + dep_thd_cnt] = entry.data2;
            }
            lut_size += warp_slices;
            rs += 32;
        }

        // Zero the padding entries read by the unrolled loop below so that
        // uninitialized shared memory (possibly NaN/Inf) is never multiplied
        // in: 0.0f * NaN would otherwise poison the output.
        int pad = lut_size + tid;
        if (pad < ((lut_size + 3) & ~3))
            lut[pad] = make_int2(0, 0);

        lutSize = lut_size;
    }

    __syncthreads();

    lut_size = lutSize;

    int rs = 0;
    float out = 0.0f;
    while (rs < lut_size)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < lut_size ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < lut_size ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < lut_size ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < lut_size ? __ldg(I + entry3.data.sliceI) : 0.0f;

        out += val0 * entry0.data.funcVal;
        out += val1 * entry1.data.funcVal;
        out += val2 * entry2.data.funcVal;
        out += val3 * entry3.data.funcVal;

        rs += 4;
    }
    *O = out;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// One lookup-table slot: a gather offset into I plus the filter weight for
// that tap, overlaid on an int2 so it moves through shared memory as a
// single 8-byte value.
union LutEntry {
    struct {
        int   sliceI;
        float funcVal;
    } data;
    int2 data2;
};

// Forward pass of Gaussian-weighted pooling without padding.
//
// Each block handles one (q, p, k) element of O and each thread one n.
// No bounds checks are made on the window, so every tap of every window must
// lie inside the input.  All threads of the block cooperate to fill the
// RS-entry table (stride blockDim.x), then each thread gathers through it.
//
// magic_S/shift_S must satisfy (rs * magic_S) >> shift_S == rs / S for
// 0 <= rs < RS (see the inline "r = rs / S" comment).
//
// The dynamic shared memory behind `lut` must hold at least RS rounded up to
// a multiple of 4 int2 entries, because the gather loop reads four entries
// per iteration.
extern "C"
__global__ void spool_fprop_gaussian_nopad(
    float* O,
    const float* I,
    int Q,
    int N,
    int QN,
    int PQN,
    int WN,
    int HWN,
    int R,
    int S,
    int RS,
    int magic_S,
    int shift_S,
    int stride_y,
    int stride_x,
    float var_y,
    float var_x,
    float mean_y,
    float mean_x
    )
{
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int q = blockIdx.x;
    int p = blockIdx.y;
    int k = blockIdx.z;

    // zigzag q back and forth to improve L2 cache perf
    if (p & 1)
        q = Q - q - 1;

    I += n;
    O += k*PQN + p*QN + q*N + n;

    int pr = p * stride_y;
    int qs = q * stride_x;

    int r_half = (R - 1) >> 1;
    int s_half = (S - 1) >> 1;

    int chan_offset = k * HWN;

    float var_y2 = var_y * 0.5f;
    float var_x2 = var_x * 0.5f;

    int rs = tid;
    while (rs < RS)
    {
        // r = rs / S;
        // s = rs % S;
        int r = rs * magic_S; r >>= shift_S;
        int s = rs - r*S;

        int x = qs + s;
        int y = pr + r;

        LutEntry entry;
        entry.data.sliceI = chan_offset + y*WN + x*N;

        float fy = (float)(r - r_half) - mean_y;
        float fx = (float)(s - s_half) - mean_x;

        entry.data.funcVal = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

        lut[rs] = entry.data2;

        rs += blockDim.x;
    }

    // Zero the padding entries read by the unrolled loop below so that
    // uninitialized shared memory (possibly NaN/Inf) is never multiplied in:
    // 0.0f * NaN would otherwise poison the output.
    for (int pad = RS + tid; pad < ((RS + 3) & ~3); pad += blockDim.x)
        lut[pad] = make_int2(0, 0);

    __syncthreads();

    rs = 0;
    float out = 0.0f;
    while (rs < RS)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < RS ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < RS ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < RS ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < RS ? __ldg(I + entry3.data.sliceI) : 0.0f;

        out += val0 * entry0.data.funcVal;
        out += val1 * entry1.data.funcVal;
        out += val2 * entry2.data.funcVal;
        out += val3 * entry3.data.funcVal;

        rs += 4;
    }
    *O = out;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment