Skip to content

Instantly share code, notes, and snippets.

@deepankarsharma
Last active May 13, 2024 07:07
Show Gist options
  • Save deepankarsharma/7955e64b423bf39a8bc32304d3be9fe3 to your computer and use it in GitHub Desktop.
Save deepankarsharma/7955e64b423bf39a8bc32304d3be9fe3 to your computer and use it in GitHub Desktop.
avx2 running sum
/// Horizontally sums the 32 bytes of `a`, treating each byte as an unsigned
/// 8-bit integer.
///
/// `_mm256_sad_epu8` against an all-zero vector collapses every group of
/// eight bytes into a single sum held in the low 16 bits of its 64-bit lane
/// (at most 8 * 255 = 2040, so it always fits). The four partial sums sit at
/// 16-bit element indices 0, 4, 8 and 12 and are added together here.
///
/// The result is always in `0..=8160` (32 * 255).
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn avx_hsum(a: __m256i) -> i32 {
    let sad = _mm256_sad_epu8(a, _mm256_setzero_si256());
    _mm256_extract_epi16::<0>(sad)
        + _mm256_extract_epi16::<4>(sad)
        + _mm256_extract_epi16::<8>(sad)
        + _mm256_extract_epi16::<12>(sad)
}
/// Negates every 8-bit lane of `v` by computing `0 - v` lane-wise with
/// wrapping arithmetic (so a lane holding -128 stays -128).
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[inline(always)]
unsafe fn negate_8bit_ints(v: __m256i) -> __m256i {
    _mm256_sub_epi8(_mm256_setzero_si256(), v)
}
/// Counts the `\n` bytes in the file at `filename` by memory-mapping it and
/// scanning 32 bytes at a time with AVX2.
///
/// Every 32-byte chunk is compared against a vector of newlines; the compare
/// yields -1 in each matching 8-bit lane, and those are accumulated into a
/// per-lane running sum. The running sum is flushed into the scalar total
/// once per block of 128 chunks so that no lane can exceed a magnitude of
/// 128, which still reads back correctly as an unsigned byte after negation.
/// Bytes that do not fill a whole 32-byte chunk are counted one by one.
///
/// Files shorter than 32 bytes (including empty ones) are handled without
/// forming any out-of-bounds pointer.
///
/// # Errors
///
/// Returns the error from opening the file, mapping it, or advising the
/// kernel about the access pattern.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn count_newlines_memmap_avx2_running_sum(filename: &str) -> Result<usize, Error> {
    /// Width of one AVX2 register in bytes.
    const LANE_BYTES: usize = 32;
    /// Number of chunks accumulated before the 8-bit lanes are flushed.
    const CHUNKS_PER_FLUSH: usize = 128;

    let file = File::open(filename)?;
    // SAFETY: NOTE(review): assumes the file is not truncated or modified by
    // another process while it is mapped - confirm against the callers.
    let mmap = unsafe { Mmap::map(&file)? };
    mmap.advise(Advice::Sequential)?;

    let newline_byte = b'\n';
    let newline_vector = _mm256_set1_epi8(newline_byte as i8);
    let mut newline_count: usize = 0;

    let bytes: &[u8] = &mmap;
    // Only the final block can be shorter than a full block, so only it can
    // leave a remainder of fewer than LANE_BYTES bytes.
    for block in bytes.chunks(LANE_BYTES * CHUNKS_PER_FLUSH) {
        let mut chunks = block.chunks_exact(LANE_BYTES);
        let mut running_sum = _mm256_setzero_si256();
        for chunk in chunks.by_ref() {
            // SAFETY: `chunks_exact` guarantees `chunk` is exactly LANE_BYTES
            // readable bytes, and the load has no alignment requirement.
            let data = unsafe { _mm256_loadu_si256(chunk.as_ptr() as *const __m256i) };
            // cmp_result has -1 in lanes holding a newline, 0 elsewhere.
            let cmp_result = _mm256_cmpeq_epi8(data, newline_vector);
            // Lanes accumulate negative counts; the sign is fixed on flush.
            running_sum = _mm256_add_epi8(running_sum, cmp_result);
        }
        newline_count += avx_hsum(negate_8bit_ints(running_sum)) as usize;
        // Scalar count of the bytes that did not fill a whole chunk.
        newline_count += chunks
            .remainder()
            .iter()
            .filter(|&&b| b == newline_byte)
            .count();
    }

    // NOTE(review): `reset_file_caches` is defined elsewhere; it is kept at
    // the same point as before, while the mapping is still alive.
    reset_file_caches();
    Ok(newline_count)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment