Created November 9, 2022 17:21
Save ned14/617ce47171c6324fc388306e5c141633 to your computer and use it in GitHub Desktop.
Many SIMD ways of finding the last zero byte in a fixed length string
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert> | |
#include <chrono> | |
#include <climits> | |
#include <cstdint> | |
#include <cstdio> | |
#include <cstdlib> | |
#include <cstring> | |
#include <span> | |
#include <vector> | |
/* | |
Straight C: 5.9, 84.37 | |
SSE2: 4.38, 228 | |
bitscan: 4.8, 86.3 | |
*/ | |
//! Evaluates the condition; if it is false, prints "!(<condition text>)" to stderr and calls abort().
//! Wrapped in do/while(0) so that `BOOST_CHECK(x);` is exactly one statement, which keeps it safe
//! inside an unbraced if/else.
#define BOOST_CHECK(...) do { if(!(__VA_ARGS__)) { fprintf(stderr, "!(" #__VA_ARGS__ ")\n"); abort(); } } while(0)
namespace utils {
/*! \class small_prng
\brief A small, fast, non-cryptographic pseudo-random number generator, from
http://burtleburtle.net/bob/rand/smallprng.html. Four 32-bit words of state.
*/
class small_prng {
protected:
  uint32_t a;
  uint32_t b;
  uint32_t c;
  uint32_t d;

  //! Rotate `x` left by `k` bits. `k` must be in [1, 31]; `k == 0` would shift by 32.
  static uint32_t rot(uint32_t x, uint32_t k) noexcept { return (x << k) | (x >> (32 - k)); }

public:
  //! The type produced by the small prng
  using value_type = uint32_t;

  //! Seed the state with `seed`, then discard the first 20 outputs to mix it.
  explicit small_prng(uint32_t seed = 0xdeadbeef) noexcept
      : a(0xf1ea5eed), b(seed), c(seed), d(seed) {
    for(int round = 20; round > 0; --round) {
      (void) (*this)();
    }
  }

  //! Advance the state once and return the next 32 bits of pseudo-randomness.
  uint32_t operator()() noexcept {
    const uint32_t e = a - rot(b, 27);
    a = b ^ rot(c, 17);
    b = c + d;
    c = d + e;
    d = e + a;
    return d;
  }
};
} // namespace utils
namespace mdx {
/*! \brief A binary symbol identifier.

This is a binary symbol identifier of up to 24 bytes in length, padded to
the right with all bits zero bytes. Any contiguous run of all bits zero
bytes to the right is used to compress storage.

The binary symbol identifier may contain all bits zero bytes. `size()` works
exclusively with the last non-zero byte.
*/
struct symbol {
  // longest name currently possible is ISIN at 22 bytes
  char name[24];

  //! Construct an empty symbol: all 24 bytes are zero.
  constexpr symbol() noexcept
      : name{} {}

  //! Construct from a character array and length. Aborts if `l > sizeof(name)`.
  //! Bytes from `l` onwards stay zero. `s` may be null when `l == 0`.
  symbol(const char* s, size_t l) noexcept
      : symbol() {
    if(l > sizeof(name)) {
      abort();
    }
    if(l != 0) {
      // memcpy with a null source pointer is undefined even for a zero length.
      memcpy(name, s, l);
    }
  }
  //! Construct from a byte array and length
  // symbol(const byte* s, size_t l) noexcept
  //   : symbol() {
  //   if(l > sizeof(name)) {
  //     abort();
  //   }
  //   memcpy(name, s, l);
  // }
  //! Construct from a span of byte equivalents
  // template <concepts::byte_equivalent T>
  // explicit symbol(span<const T> s) noexcept
  //   : symbol(s.data(), s.size()) {}
  //! Construct from a string view of characters
  // explicit symbol(string_view s) noexcept
  //   : symbol(s.data(), s.size()) {}
  //! Construct from a string literal
  // template <concepts::byte_equivalent T, size_t N>
  // explicit symbol(const T (&arr)[N]) noexcept
  //   : symbol(reinterpret_cast<const char*>(arr), N - std::is_same_v<std::decay_t<T>, char>) {}

  //! Equality: compares all 24 bytes, including the zero padding.
  bool operator==(const symbol& o) const noexcept { return 0 == memcmp(name, o.name, sizeof(name)); }
  //! Inequality
  bool operator!=(const symbol& o) const noexcept { return !(*this == o); }
  //! Ordering: bytewise lexicographic over all 24 bytes (memcmp order).
  bool operator<(const symbol& o) const noexcept { return memcmp(name, o.name, sizeof(name)) < 0; }
  //! Returns a pointer to the beginning of the binary symbol identifier
  char* data() noexcept { return name; }
  //! \overload
  const char* data() const noexcept { return name; }

  /*! Returns the index of the null byte after the very final non-null byte within the first `limit`
  bytes, or `limit` if `name[limit - 1] != 0`. Values of `limit` above 24 are treated as 24.
  Callers should pass `limit >= 1`.

  The SIMD implementations load 16 or 24 bytes starting at `name` regardless of `limit`.

  The out-of-line definition is selected by the preprocessor later in this file. There are four
  variants: SSE2, NEON, 64-bit word bitscan and plain C.
  */
  inline size_t size_within_maximum_length(size_t limit) const noexcept;

  //! Returns the index of the null byte after the very final non-null byte, or `sizeof(name)` if the last byte is
  //! non-null.
  size_t size() const noexcept { return size_within_maximum_length(sizeof(name)); }
  //! Returns true if the symbol's name has zero length.
  [[nodiscard]] bool empty() const noexcept { return size() == 0; }
  //! Returns the symbol's name as a string view (note it may contain unprintable characters).
  // string_view as_string_view() const& noexcept { return string_view(name, size()); }
  // string_view as_string_view() && = delete;
  // string_view as_string_view() const&& = delete;
};
} // namespace mdx
#if 0 // defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64) | |
#include <emmintrin.h> // for SIMD memrchr implementation | |
namespace mdx { | |
size_t symbol::size_within_maximum_length(size_t limit) const noexcept { | |
if(limit > sizeof(name)) { | |
limit = sizeof(name); | |
} | |
if(name[limit - 1] != 0) { | |
return limit; | |
} | |
auto bsr = [](int value) -> unsigned { | |
#ifdef _MSC_VER | |
unsigned long bitpos; | |
_BitScanReverse(&bitpos, value); | |
return bitpos; | |
#elif defined(__GNUC__) | |
return (sizeof(unsigned) * CHAR_BIT - 1) - (unsigned) __builtin_clz(value); | |
#else | |
#error Unknown compiler | |
#endif | |
}; | |
const __m128i zeros = _mm_setzero_si128(); | |
if(limit > 16) { | |
// We are always 24 bytes long. That is 1.5 SSE registers. | |
const __m128i back = _mm_loadu_si128((const __m128i*) (name + 8)); // don't spill past 24 bytes | |
const __m128i front = _mm_loadu_si128((const __m128i*) (name + 0)); | |
const __m128i cback = _mm_cmpeq_epi8(back, zeros); | |
const __m128i cfront = _mm_cmpeq_epi8(front, zeros); | |
uint16_t mback = ~(uint16_t) _mm_movemask_epi8(cback); | |
const uint8_t mfront = ~(uint8_t) _mm_movemask_epi8(cfront); | |
if(limit < 24) { | |
mback = uint16_t(mback << (24 - limit)) >> (24 - limit); | |
} | |
unsigned lastzeroidx = (mback != 0) ? (9 + bsr(mback)) : ((mfront != 0) ? (1 + bsr(mfront)) : 0); | |
if(lastzeroidx > limit) { | |
abort(); | |
} | |
return lastzeroidx; | |
} else { | |
const __m128i content = _mm_loadu_si128((const __m128i*) (name + 0)); | |
const __m128i cmp = _mm_cmpeq_epi8(content, zeros); | |
uint16_t nonzeros = ~(uint16_t) _mm_movemask_epi8(cmp); | |
if(limit < 16) { | |
nonzeros = uint16_t(nonzeros << (16 - limit)) >> (16 - limit); | |
} | |
unsigned lastzeroidx = (nonzeros != 0) ? (1 + bsr(nonzeros)) : 0; | |
if(lastzeroidx > limit) { | |
abort(); | |
} | |
return lastzeroidx; | |
} | |
} | |
} // namespace mdx | |
#elif defined(__aarch64__) || defined(_M_ARM64) | |
#include <arm_neon.h> // for SIMD memrchr implementation
#ifdef _MSC_VER
#include <intrin.h> // for _BitScanForward
#endif
namespace mdx {
/* NEON implementation, compiled when targeting AArch64 / ARM64.

Every byte position i < limit is given its own bit: in each group of eight bytes, byte k is ANDed
with (0x80 >> k). The bits within a group are distinct, so pairwise lane additions fold each group
of eight lanes into one byte without carries. The 24 bytes therefore collapse into three bytes.
Those are assembled into a 24-bit value `m` in which bit (23 - i) is set iff name[i] != 0 and
i < limit. The lowest set bit of `m` then identifies the last non-zero byte.
*/
size_t symbol::size_within_maximum_length(size_t limit) const noexcept {
  if(limit > sizeof(name)) {
    limit = sizeof(name);
  }
  // limit == 0 would otherwise make the fast path below read name[-1].
  if(limit == 0) {
    return 0;
  }
  if(name[limit - 1] != 0) {
    return limit;
  }
  // Index of the least significant set bit of a non-zero value.
  auto bsf = [](uint32_t value) -> unsigned {
#ifdef _MSC_VER
    unsigned long bitpos;
    _BitScanForward(&bitpos, value);
    return bitpos;
#elif defined(__GNUC__)
    return (unsigned) __builtin_ctz(value);
#else
#error Unknown compiler
#endif
  };
  // One distinct bit per lane within each group of eight; lanes at or beyond `limit` are zeroed
  // so that bytes past the limit never count as non-zero.
  static constexpr uint8_t lane_bits[24] = {
      0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01,
      0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01,
      0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01};
  uint8_t mask_bytes[24];
  memcpy(mask_bytes, lane_bits, sizeof(mask_bytes));
  memset(mask_bytes + limit, 0, sizeof(mask_bytes) - limit);
  const uint8x16_t mask_front = vld1q_u8(mask_bytes);    // name[0..15]
  const uint8x8_t mask_back = vld1_u8(mask_bytes + 16);  // name[16..23]
  if(limit > 16) {
    // We are always 24 bytes long. NEON unlike SSE can do half registers.
    const uint8x8_t back = vld1_u8(reinterpret_cast<const uint8_t*>(name + 16));
    const uint8x16_t front = vld1q_u8(reinterpret_cast<const uint8_t*>(name + 0));
    // vtst yields 0xff in lanes holding a non-zero byte; the mask keeps that lane's single bit.
    const uint8x8_t cback = vand_u8(vtst_u8(back, back), mask_back);
    const uint8x16_t cfront = vandq_u8(vtstq_u8(front, front), mask_front);
    // Pairwise additions reduce the 24 lanes to three bytes, one per group of eight input bytes.
    const uint8x8_t rback = vpadd_u8(cback, cback);                   // 8 to 4
    const uint8x8_t rfront = vget_low_u8(vpaddq_u8(cfront, cfront));  // 16 to 8
    const uint8x16_t r1 = vcombine_u8(rfront, rback);                 // 8 + 4 = 12
    const uint8x8_t r2 = vget_low_u8(vpaddq_u8(r1, r1));              // 12 to 6
    const uint8x8_t r3 = vpadd_u8(r2, r2);                            // 6 to 3
    // Lanes 0, 1 and 2 summarise name[0..7], name[8..15] and name[16..23] respectively.
    const uint32_t m = (uint32_t(vget_lane_u8(r3, 0)) << 16) | (uint32_t(vget_lane_u8(r3, 1)) << 8) |
                       uint32_t(vget_lane_u8(r3, 2));
    // m == 0 means no non-zero byte below limit.
    unsigned length = 0;
    if(m != 0) {
      length = 24 - bsf(m);
    }
    if(length > limit) {
      abort(); // internal consistency check
    }
    return length;
  } else {
    const uint8x16_t content = vld1q_u8(reinterpret_cast<const uint8_t*>(name + 0));
    const uint8x16_t cmp = vandq_u8(vtstq_u8(content, content), mask_front);
    const uint8x8_t r1 = vget_low_u8(vpaddq_u8(cmp, cmp)); // 16 to 8
    const uint8x8_t r2 = vpadd_u8(r1, r1);                 // 8 to 4
    const uint8x8_t r3 = vpadd_u8(r2, r2);                 // 4 to 2
    // Lanes 0 and 1 summarise name[0..7] and name[8..15]; bit (15 - i) <=> name[i] != 0.
    const uint32_t m = (uint32_t(vget_lane_u8(r3, 0)) << 8) | uint32_t(vget_lane_u8(r3, 1));
    unsigned length = 0;
    if(m != 0) {
      length = 16 - bsf(m);
    }
    if(length > limit) {
      abort(); // internal consistency check
    }
    return length;
  }
}
} // namespace mdx
#elif 1 | |
namespace mdx { | |
size_t symbol::size_within_maximum_length(size_t limit) const noexcept { | |
if(limit > sizeof(name)) { | |
limit = sizeof(name); | |
} | |
if(name[limit - 1] != 0) { | |
return limit; | |
} | |
auto bsr = [](uint64_t value) -> unsigned { | |
#ifdef _MSC_VER | |
unsigned long bitpos; | |
63 - _BitScanReverse64(&bitpos, value); | |
return bitpos; | |
#elif defined(__GNUC__) | |
return __builtin_clzll(value); | |
#else | |
#error Unknown compiler | |
#endif | |
}; | |
const uint64_t *v = (const uint64_t *) name; | |
if(limit > 16 && v[2]!=0) | |
{ | |
auto x =v[2]; | |
if(limit < 24) { | |
const auto shift = (24-limit)<<3; | |
x = (x << shift) >> (shift); | |
} | |
if(x!=0) | |
{ | |
return 24 - (bsr(x)>>3); | |
} | |
} | |
if(limit > 8 && v[1]!=0) | |
{ | |
auto x =v[1]; | |
if(limit < 16) { | |
const auto shift = (16-limit)<<3; | |
x = (x << shift) >> (shift); | |
} | |
if(x!=0) | |
{ | |
return 16 - (bsr(x)>>3); | |
} | |
} | |
if(limit > 0 && v[0]!=0) | |
{ | |
auto x =v[0]; | |
if(limit < 8) { | |
const auto shift = (8-limit)<<3; | |
x = (x << shift) >> (shift); | |
} | |
if(x!=0) | |
{ | |
return 8 - (bsr(x)>>3); | |
} | |
} | |
return 0; | |
} | |
} // namespace mdx | |
#else | |
namespace mdx { | |
size_t symbol::size_within_maximum_length(size_t limit) const noexcept { | |
if(limit > sizeof(name)) { | |
limit = sizeof(name); | |
} | |
if(name[limit - 1] != 0) { | |
return limit; | |
} | |
while(limit-- > 0) { | |
if(name[limit] != 0) { | |
return limit + 1; | |
} | |
} | |
return 0; | |
} | |
} // namespace mdx | |
#endif | |
int main(void) | |
{ | |
using namespace mdx; | |
using mdx::symbol; | |
utils::small_prng rand; | |
{ | |
auto begin = std::chrono::high_resolution_clock::now(); | |
while(std::chrono::high_resolution_clock::now()-begin<std::chrono::seconds(1)); | |
} | |
std::vector<std::pair<symbol, size_t>> symbols(50000000); | |
for(size_t n = 0; n < 50000000; n++) { | |
auto* c = (unsigned*) symbols[n].first.name; | |
c[0] = rand(); | |
c[1] = rand(); | |
c[2] = rand(); | |
c[3] = rand(); | |
c[4] = rand(); | |
c[5] = rand(); | |
auto r = rand(); | |
auto &s=symbols[n].first; | |
// Place a random zero byte somewhere | |
s.name[r % 24] = 0; | |
//printf("\n\nzero byte set at %u\n", r % 24); | |
// Make some number of the end zero bytes | |
r >>= 28; | |
auto l = 24 - r; | |
memset(s.name + l, 0, r); | |
while(s.name[l - 1] == 0) { | |
r++; | |
l--; | |
} | |
symbols[n].second=l; | |
} | |
auto begin = std::chrono::high_resolution_clock::now(); | |
for(size_t n = 0; n < 500000000; n++) { | |
auto &s=symbols[n % 50000000].first; | |
auto l=symbols[n % 50000000].second; | |
//printf("\nlength: %zu\n", l); | |
BOOST_CHECK(s.size() == l); | |
#if 1 | |
BOOST_CHECK(s.size_within_maximum_length(l + 1) == l); | |
BOOST_CHECK(s.size_within_maximum_length(l) == l); | |
BOOST_CHECK(s.size_within_maximum_length(l - 1) <= l - 1); | |
#endif | |
} | |
auto end = std::chrono::high_resolution_clock::now(); | |
auto diff = std::chrono::duration_cast<std::chrono::milliseconds>(end-begin); | |
printf("%f\n", diff.count()/1000.0); | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment