This gist accompanies a Medium story I wrote about optimizing parallel prefix sum in Metal, with comparisons to optimized CPU code on Apple M1 (https://kieber-emmons.medium.com/efficient-parallel-prefix-sum-in-metal-for-apple-m1-9e60b974d62).
//
//  ParallelScan.metal
//
//  Copyright © 2021 Matthew Kieber-Emmons. All rights reserved.
//  For educational purposes only.
//
#include <metal_stdlib>
using namespace metal;
////////////////////////////////////////////////////////////////
// MARK: - Function Constants
////////////////////////////////////////////////////////////////
// these constants control the code paths at pipeline creation
constant int LOCAL_ALGORITHM [[function_constant(0)]];
constant int GLOBAL_ALGORITHM [[function_constant(1)]];
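// LOCAL_ALGORITHM selects the threadgroup scan/reduce variant (see the switch
// statements in the kernels below): 0 = Blelloch, 1 = raking, 2 = cooperative.
// GLOBAL_ALGORITHM selects the multi-level behavior of PrefixScanKernel:
// 0 = scan only, 1 = also store each threadgroup's inclusive total to
// partial_sums, 2 = add the (already scanned) partial_sums back in as a prefix.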
///////////////////////////////////////////////////////////////////////////////
// MARK: - Load and Store Functions
///////////////////////////////////////////////////////////////////////////////
// this is a blocked read into registers without bounds checking
template<ushort GRAIN_SIZE, typename T> static void
LoadBlockedLocalFromGlobal(thread T (&value)[GRAIN_SIZE],
                           const device T* input_data,
                           const ushort local_id) {
    for (ushort i = 0; i < GRAIN_SIZE; i++){
        value[i] = input_data[local_id * GRAIN_SIZE + i];
    }
}
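// For example, with GRAIN_SIZE = 4, thread 0 reads elements 0..3, thread 1
// reads elements 4..7, and so on (a "blocked" rather than strided layout).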
//------------------------------------------------------------------------------------------------//
// this is a blocked read into registers with bounds checking
template<ushort GRAIN_SIZE, typename T> static void
LoadBlockedLocalFromGlobal(thread T (&value)[GRAIN_SIZE],
                           const device T* input_data,
                           const ushort local_id,
                           const uint n,
                           const T substitution_value) {
    for (ushort i = 0; i < GRAIN_SIZE; i++){
        value[i] = (local_id * GRAIN_SIZE + i < n) ? input_data[local_id * GRAIN_SIZE + i] : substitution_value;
    }
}
//------------------------------------------------------------------------------------------------//
// this is a blocked write into global without bounds checking
template<ushort GRAIN_SIZE, typename T> static void
StoreBlockedLocalToGlobal(device T *output_data,
                          thread const T (&value)[GRAIN_SIZE],
                          const ushort local_id) {
    for (ushort i = 0; i < GRAIN_SIZE; i++){
        output_data[local_id * GRAIN_SIZE + i] = value[i];
    }
}
//------------------------------------------------------------------------------------------------//
// this is a blocked write into global with bounds checking
template<ushort GRAIN_SIZE, typename T> static void
StoreBlockedLocalToGlobal(device T *output_data,
                          thread const T (&value)[GRAIN_SIZE],
                          const ushort local_id,
                          const uint n) {
    for (ushort i = 0; i < GRAIN_SIZE; i++){
        if (local_id * GRAIN_SIZE + i < n)
            output_data[local_id * GRAIN_SIZE + i] = value[i];
    }
}
///////////////////////////////////////////////////////////////////////////////
// MARK: - Thread Functions
///////////////////////////////////////////////////////////////////////////////
template<ushort LENGTH, typename T>
static inline T ThreadPrefixInclusiveSum(thread T (&values)[LENGTH]){
    for (ushort i = 1; i < LENGTH; i++){
        values[i] += values[i - 1];
    }
    return values[LENGTH - 1];
}
template<ushort LENGTH, typename T>
static inline T ThreadPrefixInclusiveSum(threadgroup T* values){
    for (ushort i = 1; i < LENGTH; i++){
        values[i] += values[i - 1];
    }
    return values[LENGTH - 1];
}
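// For example, {3, 1, 4, 1} is scanned in place to {3, 4, 8, 9}, and the
// returned aggregate is 9.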
//------------------------------------------------------------------------------------------------//
template<ushort LENGTH, typename T>
static inline T ThreadPrefixExclusiveSum(thread T (&values)[LENGTH]){
    // do as an inclusive scan first
    T inclusive_prefix = ThreadPrefixInclusiveSum<LENGTH>(values);
    // convert to an exclusive scan in the reverse direction
    for (ushort i = LENGTH - 1; i > 0; i--){
        values[i] = values[i - 1];
    }
    values[0] = 0;
    return inclusive_prefix;
}
template<ushort LENGTH, typename T>
static inline T ThreadPrefixExclusiveSum(threadgroup T* values){
    // do as an inclusive scan first
    T inclusive_prefix = ThreadPrefixInclusiveSum<LENGTH>(values);
    // convert to an exclusive scan in the reverse direction
    for (ushort i = LENGTH - 1; i > 0; i--){
        values[i] = values[i - 1];
    }
    values[0] = 0;
    return inclusive_prefix;
}
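// For example, {3, 1, 4, 1} becomes {0, 3, 4, 8} in place, and the returned
// inclusive prefix (the total sum) is 9.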
//------------------------------------------------------------------------------------------------//
template<ushort LENGTH, typename T> static inline void
ThreadUniformAdd(thread T (&values)[LENGTH], T uni){
    for (ushort i = 0; i < LENGTH; i++){
        values[i] += uni;
    }
}
template<ushort LENGTH, typename T> static inline void
ThreadUniformAdd(threadgroup T* values, T uni){
    for (ushort i = 0; i < LENGTH; i++){
        values[i] += uni;
    }
}
//------------------------------------------------------------------------------------------------//
template<ushort LENGTH, typename T> static inline T
ThreadReduce(thread T (&values)[LENGTH]) {
    T reduction = values[0];
    for (ushort i = 1; i < LENGTH; i++){
        reduction += values[i];
    }
    return reduction;
}
//------------------------------------------------------------------------------------------------//
template<ushort LENGTH, typename T> static inline T
ThreadReduce(threadgroup T* values) {
    T reduction = values[0];
    for (ushort i = 1; i < LENGTH; i++){
        reduction += values[i];
    }
    return reduction;
}
///////////////////////////////////////////////////////////////////////////////
// MARK: - Threadgroup Functions
///////////////////////////////////////////////////////////////////////////////
// Work-efficient exclusive scan in shared memory from Blelloch 1990
template<ushort BLOCK_SIZE, typename T> static T
ThreadgroupBlellochUnoptimizedPrefixExclusiveSum(T value, threadgroup T* sdata, const ushort lid) {
    // load input into shared memory
    sdata[lid] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const ushort ai = 2 * lid + 1;
    const ushort bi = 2 * lid + 2;
    // build the sum in place up the tree
    ushort stride = 1;
    for (ushort n = BLOCK_SIZE / 2; n > 0; n /= 2){
        if (lid < n) {
            sdata[stride * bi - 1] += sdata[stride * ai - 1];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        stride *= 2;
    }
    // clear and optionally store the last element
    if (lid == 0) { sdata[BLOCK_SIZE - 1] = 0; }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // traverse down the tree building the scan in place
    for (ushort n = 1; n < BLOCK_SIZE; n *= 2) {
        stride /= 2;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (lid < n) {
            T temp = sdata[stride * ai - 1];
            sdata[stride * ai - 1] = sdata[stride * bi - 1];
            sdata[stride * bi - 1] += temp;
        }
    }
    // return result
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return sdata[lid];
}
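// A worked example with BLOCK_SIZE = 8 and input {3,1,7,0,4,1,6,3}:
// after the upsweep the array holds {3,4,7,11,4,5,6,25} (25 is the total),
// the last element is zeroed, and the downsweep produces the exclusive
// scan {0,3,4,11,11,15,16,22}.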
//------------------------------------------------------------------------------------------------//
// Optimized version of the Blelloch Scan
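// Compared to the loop-based version above, the up- and down-sweeps here are
// unrolled at compile time, and threadgroup barriers are elided once the
// active threads fit within a single simdgroup (the "> 32" guards), relying
// on lockstep execution within a simdgroup.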
template<ushort BLOCK_SIZE, typename T> static T
ThreadgroupBlellochPrefixExclusiveSum(T value, threadgroup T* sdata, const ushort lid) {
    // store values to shared memory
    sdata[lid] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const ushort ai = 2 * lid + 1;
    const ushort bi = 2 * lid + 2;
    // build the sum in place up the tree
    if (BLOCK_SIZE >= 2)    {if (lid < (BLOCK_SIZE >> 1))  {sdata[  1 * bi - 1] += sdata[  1 * ai - 1];} if ((BLOCK_SIZE >> 0) > 32) threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 4)    {if (lid < (BLOCK_SIZE >> 2))  {sdata[  2 * bi - 1] += sdata[  2 * ai - 1];} if ((BLOCK_SIZE >> 1) > 32) threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 8)    {if (lid < (BLOCK_SIZE >> 3))  {sdata[  4 * bi - 1] += sdata[  4 * ai - 1];} if ((BLOCK_SIZE >> 2) > 32) threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 16)   {if (lid < (BLOCK_SIZE >> 4))  {sdata[  8 * bi - 1] += sdata[  8 * ai - 1];} if ((BLOCK_SIZE >> 3) > 32) threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 32)   {if (lid < (BLOCK_SIZE >> 5))  {sdata[ 16 * bi - 1] += sdata[ 16 * ai - 1];} if ((BLOCK_SIZE >> 4) > 32) threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 64)   {if (lid < (BLOCK_SIZE >> 6))  {sdata[ 32 * bi - 1] += sdata[ 32 * ai - 1];} }
    if (BLOCK_SIZE >= 128)  {if (lid < (BLOCK_SIZE >> 7))  {sdata[ 64 * bi - 1] += sdata[ 64 * ai - 1];} }
    if (BLOCK_SIZE >= 256)  {if (lid < (BLOCK_SIZE >> 8))  {sdata[128 * bi - 1] += sdata[128 * ai - 1];} }
    if (BLOCK_SIZE >= 512)  {if (lid < (BLOCK_SIZE >> 9))  {sdata[256 * bi - 1] += sdata[256 * ai - 1];} }
    if (BLOCK_SIZE >= 1024) {if (lid < (BLOCK_SIZE >> 10)) {sdata[512 * bi - 1] += sdata[512 * ai - 1];} }
    // clear and optionally store the last element
    if (lid == 0){
        sdata[BLOCK_SIZE - 1] = 0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // traverse down the tree building the scan in place
    if (BLOCK_SIZE >= 2){
        if (lid < 1) {
            sdata[(BLOCK_SIZE >> 1) * bi - 1] += sdata[(BLOCK_SIZE >> 1) * ai - 1];
            sdata[(BLOCK_SIZE >> 1) * ai - 1] = sdata[(BLOCK_SIZE >> 1) * bi - 1] - sdata[(BLOCK_SIZE >> 1) * ai - 1];
        }
    }
    if (BLOCK_SIZE >= 4){    if (lid < 2)   {sdata[(BLOCK_SIZE >> 2) * bi - 1] += sdata[(BLOCK_SIZE >> 2) * ai - 1]; sdata[(BLOCK_SIZE >> 2) * ai - 1] = sdata[(BLOCK_SIZE >> 2) * bi - 1] - sdata[(BLOCK_SIZE >> 2) * ai - 1];} }
    if (BLOCK_SIZE >= 8){    if (lid < 4)   {sdata[(BLOCK_SIZE >> 3) * bi - 1] += sdata[(BLOCK_SIZE >> 3) * ai - 1]; sdata[(BLOCK_SIZE >> 3) * ai - 1] = sdata[(BLOCK_SIZE >> 3) * bi - 1] - sdata[(BLOCK_SIZE >> 3) * ai - 1];} }
    if (BLOCK_SIZE >= 16){   if (lid < 8)   {sdata[(BLOCK_SIZE >> 4) * bi - 1] += sdata[(BLOCK_SIZE >> 4) * ai - 1]; sdata[(BLOCK_SIZE >> 4) * ai - 1] = sdata[(BLOCK_SIZE >> 4) * bi - 1] - sdata[(BLOCK_SIZE >> 4) * ai - 1];} }
    if (BLOCK_SIZE >= 32){   if (lid < 16)  {sdata[(BLOCK_SIZE >> 5) * bi - 1] += sdata[(BLOCK_SIZE >> 5) * ai - 1]; sdata[(BLOCK_SIZE >> 5) * ai - 1] = sdata[(BLOCK_SIZE >> 5) * bi - 1] - sdata[(BLOCK_SIZE >> 5) * ai - 1];} }
    if (BLOCK_SIZE >= 64){   if (lid < 32)  {sdata[(BLOCK_SIZE >> 6) * bi - 1] += sdata[(BLOCK_SIZE >> 6) * ai - 1]; sdata[(BLOCK_SIZE >> 6) * ai - 1] = sdata[(BLOCK_SIZE >> 6) * bi - 1] - sdata[(BLOCK_SIZE >> 6) * ai - 1];} threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 128){  if (lid < 64)  {sdata[(BLOCK_SIZE >> 7) * bi - 1] += sdata[(BLOCK_SIZE >> 7) * ai - 1]; sdata[(BLOCK_SIZE >> 7) * ai - 1] = sdata[(BLOCK_SIZE >> 7) * bi - 1] - sdata[(BLOCK_SIZE >> 7) * ai - 1];} threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 256){  if (lid < 128) {sdata[(BLOCK_SIZE >> 8) * bi - 1] += sdata[(BLOCK_SIZE >> 8) * ai - 1]; sdata[(BLOCK_SIZE >> 8) * ai - 1] = sdata[(BLOCK_SIZE >> 8) * bi - 1] - sdata[(BLOCK_SIZE >> 8) * ai - 1];} threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 512){  if (lid < 256) {sdata[(BLOCK_SIZE >> 9) * bi - 1] += sdata[(BLOCK_SIZE >> 9) * ai - 1]; sdata[(BLOCK_SIZE >> 9) * ai - 1] = sdata[(BLOCK_SIZE >> 9) * bi - 1] - sdata[(BLOCK_SIZE >> 9) * ai - 1];} threadgroup_barrier(mem_flags::mem_threadgroup); }
    if (BLOCK_SIZE >= 1024){ if (lid < 512) {sdata[(BLOCK_SIZE >> 10) * bi - 1] += sdata[(BLOCK_SIZE >> 10) * ai - 1]; sdata[(BLOCK_SIZE >> 10) * ai - 1] = sdata[(BLOCK_SIZE >> 10) * bi - 1] - sdata[(BLOCK_SIZE >> 10) * ai - 1];} threadgroup_barrier(mem_flags::mem_threadgroup); }
    return sdata[lid];
}
//------------------------------------------------------------------------------------------------//
// Raking threadgroup scan
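// The first simdgroup "rakes" over shared memory: each of its 32 threads
// sequentially scans a contiguous segment of BLOCK_SIZE / 32 values, the
// per-segment totals are scanned with simd_prefix_exclusive_sum, and each
// segment is then offset by its prefix. Assumes BLOCK_SIZE is a multiple of 32.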
template<ushort BLOCK_SIZE, typename T> static T
ThreadgroupRakingPrefixExclusiveSum(T value, threadgroup T* shared, const ushort lid) {
    // load values into shared memory
    shared[lid] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // only the first 32 threads form the rake
    if (lid < 32){
        // scan by thread in shared mem
        const short values_per_thread = BLOCK_SIZE / 32;
        const short first_index = lid * values_per_thread;
        for (short i = first_index + 1; i < first_index + values_per_thread; i++){
            shared[i] += shared[i - 1];
        }
        T partial_sum = shared[first_index + values_per_thread - 1];
        for (short i = first_index + values_per_thread - 1; i > first_index; i--){
            shared[i] = shared[i - 1];
        }
        shared[first_index] = 0;
        // scan the partial sums
        T prefix = simd_prefix_exclusive_sum(partial_sum);
        // add back the prefix
        for (short i = first_index; i < first_index + values_per_thread; i++){
            shared[i] += prefix;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return shared[lid];
}
//------------------------------------------------------------------------------------------------//
// Cooperative threadgroup scan
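// Two-level scan: every simdgroup scans its values with
// simd_prefix_exclusive_sum, the per-simdgroup totals are scanned through
// shared memory by the first simdgroup, and each thread then adds its
// simdgroup's prefix back in.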
template<ushort BLOCK_SIZE, typename T> static T
ThreadgroupCooperativePrefixExclusiveSum(T value, threadgroup T* sdata, const ushort lid) {
    // first level of reduction in simdgroup
    T scan = 0;
    scan = simd_prefix_exclusive_sum(value);
    // return early if our block size is 32
    if (BLOCK_SIZE == 32){
        return scan;
    }
    // store inclusive sums into shared[0...31]
    if ( (lid % 32) == (32 - 1) ){
        sdata[lid / 32] = scan + value;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // scan the shared memory
    if (lid < 32) {
        sdata[lid] = simd_prefix_exclusive_sum(sdata[lid]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // the scan is the sum of the partial sum prefix scan and the original value
    return scan + sdata[lid / 32];
}
//------------------------------------------------------------------------------------------------//
// This kernel is a work-efficient but moderately cost-inefficient reduction in shared memory.
// Kernel is inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris:
// https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
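// For example, with BLOCK_SIZE = 256 the shared array is folded 256 -> 128 -> 64,
// then the first simdgroup folds once more to 32 values and finishes with simd_sum.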
template <ushort BLOCK_SIZE, typename T> static T
ThreadgroupReduceSharedMemAlgorithm(T value, threadgroup T* shared, const ushort lid){
    // copy values to shared memory
    shared[lid] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // initial reductions in shared memory
    if (BLOCK_SIZE >= 1024) {if (lid < 512) {shared[lid] += shared[lid + 512];} threadgroup_barrier(mem_flags::mem_threadgroup);}
    if (BLOCK_SIZE >= 512)  {if (lid < 256) {shared[lid] += shared[lid + 256];} threadgroup_barrier(mem_flags::mem_threadgroup);}
    if (BLOCK_SIZE >= 256)  {if (lid < 128) {shared[lid] += shared[lid + 128];} threadgroup_barrier(mem_flags::mem_threadgroup);}
    if (BLOCK_SIZE >= 128)  {if (lid < 64)  {shared[lid] += shared[lid + 64];}  threadgroup_barrier(mem_flags::mem_threadgroup);}
    // final reduction in shared warp
    if (lid < 32){
        // we fold one more time
        if (BLOCK_SIZE >= 64) {
            shared[lid] += shared[lid + 32];
            simdgroup_barrier(mem_flags::mem_threadgroup);
        }
        value = simd_sum(shared[lid]);
    }
    // only the result in thread 0 is defined
    return value;
}
//------------------------------------------------------------------------------------------------//
// This kernel is a work- and cost-efficient rake in shared memory.
// Kernel is inspired by the CUB library by NVIDIA
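// Each of the 32 raking threads accumulates BLOCK_SIZE / 32 values at a
// stride of 32, leaving one partial sum per thread, which simd_sum reduces.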
template <ushort BLOCK_SIZE, typename T> static T
ThreadgroupReduceRakingAlgorithm(T value, threadgroup T* shared, const ushort lid){
    // copy values to shared memory
    shared[lid] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // first warp reduces all values
    if (lid < 32){
        // interleaved addressing to reduce values into 0...31
        for (short i = 1; i < BLOCK_SIZE / 32; i++){
            shared[lid] += shared[lid + 32 * i];
        }
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // final reduction
        value = simd_sum(shared[lid]);
    }
    // only the result in thread 0 is defined
    return value;
}
//------------------------------------------------------------------------------------------------//
// This is a highly parallel but not cost-efficient algorithm
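// Each simdgroup reduces its values with simd_sum, the per-simdgroup sums are
// gathered into shared memory, and the first simdgroup reduces those to the
// final value.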
template <ushort BLOCK_SIZE, typename T> static T
ThreadgroupReduceCooperativeAlgorithm(T value, threadgroup T* shared, const ushort lid){
    // first level of reduction in simdgroup
    value = simd_sum(value);
    // return early if our block size is 32
    if (BLOCK_SIZE == 32){
        return value;
    }
    // first simd lane writes to shared
    if (lid % 32 == 0)
        shared[lid / 32] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // final reduction in first simdgroup
    if (lid < 32){
        // mask the values on copy
        value = (lid < BLOCK_SIZE / 32) ? shared[lid] : 0;
        // final reduction
        value = simd_sum(value);
    }
    // only the result in thread 0 is defined
    return value;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// MARK: - Multi-level scan kernels
////////////////////////////////////////////////////////////////////////////////////////////////////
template<ushort BLOCK_SIZE, ushort GRAIN_SIZE, typename T> kernel void
PrefixScanKernel(device T* output_data,
                 device const T* input_data,
                 constant uint& n,
                 device T* partial_sums,
                 uint group_id [[threadgroup_position_in_grid]],
                 ushort local_id [[thread_position_in_threadgroup]]) {
    uint base_id = group_id * BLOCK_SIZE * GRAIN_SIZE;
    // load values into registers
    T values[GRAIN_SIZE];
    LoadBlockedLocalFromGlobal(values, &input_data[base_id], local_id);
    // sequentially scan the values in registers in place
    T aggregate = ThreadPrefixExclusiveSum<GRAIN_SIZE>(values);
    // scan the aggregates
    T prefix = 0;
    threadgroup T scratch[BLOCK_SIZE];
    switch (LOCAL_ALGORITHM){
        case 0:
            prefix = ThreadgroupBlellochPrefixExclusiveSum<BLOCK_SIZE,T>(aggregate, scratch, local_id);
            break;
        case 1:
            prefix = ThreadgroupRakingPrefixExclusiveSum<BLOCK_SIZE,T>(aggregate, scratch, local_id);
            break;
        case 2:
            prefix = ThreadgroupCooperativePrefixExclusiveSum<BLOCK_SIZE,T>(aggregate, scratch, local_id);
            break;
    }
    // optionally load or store the inclusive sum as needed
    switch (GLOBAL_ALGORITHM){
        case 0:
            // no op
            break;
        case 1:
            if (local_id == BLOCK_SIZE - 1)
                partial_sums[group_id] = aggregate + prefix;
            threadgroup_barrier(mem_flags::mem_none);
            break;
        case 2:
            if (local_id == 0)
                scratch[0] = partial_sums[group_id];
            threadgroup_barrier(mem_flags::mem_threadgroup);
            prefix += scratch[0];
            break;
    }
    // sequentially add the scan and prefix to the values in place
    ThreadUniformAdd<GRAIN_SIZE>(values, prefix);
    // store to global
    StoreBlockedLocalToGlobal(&output_data[base_id], values, local_id);
}
#if defined(THREADS_PER_THREADGROUP) && defined(VALUES_PER_THREAD)
template [[host_name("prefix_exclusive_scan_uint32")]] kernel void PrefixScanKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD>(device uint*, device const uint*, constant uint&, device uint*, uint, ushort);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<ushort BLOCK_SIZE, ushort GRAIN_SIZE, typename T> kernel void
ReduceKernel(device T* output_data,
             device const T* input_data,
             constant uint& n,
             uint group_id [[ threadgroup_position_in_grid ]],
             ushort local_id [[ thread_index_in_threadgroup ]]) {
    uint base_id = group_id * BLOCK_SIZE * GRAIN_SIZE;
    // load from global
    T values[GRAIN_SIZE];
    LoadBlockedLocalFromGlobal(values, &input_data[base_id], local_id);
    // reduce by thread
    T value = ThreadReduce<GRAIN_SIZE>(values);
    // reduce the values from each thread in the threadgroup
    threadgroup T scratch[BLOCK_SIZE];
    switch (LOCAL_ALGORITHM){
        case 0:
            value = ThreadgroupReduceSharedMemAlgorithm<BLOCK_SIZE>(value, scratch, local_id);
            break;
        case 1:
            value = ThreadgroupReduceRakingAlgorithm<BLOCK_SIZE>(value, scratch, local_id);
            break;
        case 2:
            value = ThreadgroupReduceCooperativeAlgorithm<BLOCK_SIZE>(value, scratch, local_id);
            break;
    }
    // write result to global memory
    if (local_id == 0)
        output_data[group_id] = value;
}
#if defined(THREADS_PER_THREADGROUP) && defined(VALUES_PER_THREAD)
template [[host_name("reduce_uint32")]] kernel void ReduceKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD>(device uint*, device const uint*, constant uint&, uint, ushort);
#endif
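//------------------------------------------------------------------------------------------------//
// A sketch of one possible reduce-then-scan dispatch sequence on the host,
// assuming n exceeds the BLOCK_SIZE * GRAIN_SIZE elements one threadgroup covers:
//   1. ReduceKernel writes one partial sum per threadgroup to a scratch buffer.
//   2. PrefixScanKernel with GLOBAL_ALGORITHM = 0 (a single threadgroup) scans
//      the partial sums in place.
//   3. PrefixScanKernel with GLOBAL_ALGORITHM = 2 scans the input and adds each
//      threadgroup's scanned partial sum as its prefix.
// Alternatively, pass 1 can be PrefixScanKernel with GLOBAL_ALGORITHM = 1,
// which scans each block and stores the block totals in a single pass.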