Created
July 5, 2020 12:30
-
-
Save syoyo/8e980a1c04f4253894d32da25d68bad4 to your computer and use it in GitHub Desktop.
_mm_min_ps implementation in ARM NEON
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <arm_neon.h>

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>
// Returns true when `f` is a signaling NaN.
//
// A value is reported as signaling when std::isnan() holds and bit 22 of its
// binary32 encoding (the most significant fraction bit, which marks a quiet
// NaN) is clear.
bool check_snan(float f)
{
  static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits wide");

  // Copy the encoding out with memcpy; reading a float object through a
  // uint32_t pointer violates the strict-aliasing rule.
  uint32_t bits = 0;
  std::memcpy(&bits, &f, sizeof(bits));

  const bool quiet_bit_set = (bits & 0x00400000u) != 0; // qNaN bit
  return std::isnan(f) && !quiet_bit_set;
}
// Returns true when `f` is a quiet NaN, i.e. all eight exponent bits and the
// most significant fraction bit (bit 22) of its binary32 encoding are set.
// The sign bit is ignored. Ordinary numbers, infinities and signaling NaNs
// all yield false.
bool check_qnan(float f)
{
  static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits wide");

  // memcpy instead of a pointer cast keeps the bit read free of aliasing UB.
  uint32_t bits = 0;
  std::memcpy(&bits, &f, sizeof(bits));

  // Every bit of the mask has to be set. Testing `bits & mask` for non-zero
  // would accept any value with a non-zero exponent, such as 1.0f.
  constexpr uint32_t kExpAndQuietBit = 0x7fc00000u; // exp + qNaN bit
  return (bits & kExpAndQuietBit) == kExpAndQuietBit;
}
// Check if input is sNaN.
//
// Returns a per-lane mask (all ones / all zeros) that is set where the lane of
// `a` encodes a signaling NaN:
//   - all exponent bits are 1,
//   - the most significant fraction bit (the quiet bit, 0x00400000) is 0,
//   - at least one of the remaining fraction bits is 1 (otherwise the value
//     would be an infinity).
// The sign bit is ignored. The mask is also printed to stdout as a trace.
inline uint32x4_t is_snan(float32x4_t a)
{
  const uint32x4_t v_bits = vreinterpretq_u32_f32(a);
  const uint32x4_t v_exp_quiet_mask = vdupq_n_u32(0x7fc00000); // exp + quiet bit
  const uint32x4_t v_exp_only = vdupq_n_u32(0x7f800000);       // exp set, quiet bit clear
  const uint32x4_t v_payload_mask = vdupq_n_u32(0x003fffff);   // fraction below the quiet bit

  // Exponent saturated and quiet bit clear. Comparing against 0x7fc00000 here
  // would select quiet NaNs instead.
  const uint32x4_t v_exp_set_quiet_clear =
      vceqq_u32(vandq_u32(v_bits, v_exp_quiet_mask), v_exp_only);
  // A non-zero payload separates an sNaN from +/-infinity.
  const uint32x4_t v_payload_nonzero = vtstq_u32(v_bits, v_payload_mask);

  const uint32x4_t ret = vandq_u32(v_exp_set_quiet_clear, v_payload_nonzero);

  uint32_t mbuf[4];
  vst1q_u32(mbuf, ret);
  printf("v_is_snan = %x, %x, %x, %x\n",
         mbuf[0],
         mbuf[1],
         mbuf[2],
         mbuf[3]);
  return ret;
}
// Check if input is NaN (sNaN or qNaN).
//
// Produces a per-lane mask that is all ones where the lane of `a` has every
// exponent bit set and a non-zero fraction, and all zeros elsewhere. The mask
// is printed to stdout before it is returned.
inline uint32x4_t is_nan(float32x4_t a)
{
  const uint32x4_t lane_bits = vreinterpretq_u32_f32(a);
  const uint32x4_t exponent_field = vdupq_n_u32(0x7f800000);
  const uint32x4_t fraction_field = vdupq_n_u32(0x007fffff);

  // Exponent field saturated (shared by infinities and NaNs).
  const uint32x4_t exponent_saturated =
      vceqq_u32(vandq_u32(lane_bits, exponent_field), exponent_field);
  // Any fraction bit set: this is what separates a NaN from an infinity.
  const uint32x4_t fraction_nonzero = vtstq_u32(lane_bits, fraction_field);

  const uint32x4_t nan_lanes = vandq_u32(exponent_saturated, fraction_nonzero);

  printf("v_is_nan = %x, %x, %x, %x\n",
         vgetq_lane_u32(nan_lanes, 0),
         vgetq_lane_u32(nan_lanes, 1),
         vgetq_lane_u32(nan_lanes, 2),
         vgetq_lane_u32(nan_lanes, 3));
  return nan_lanes;
}
// Writes the four 32-bit lanes of `v` to stdout in hexadecimal, formatted as
// "<title> = l0, l1, l2, l3" followed by a newline.
void print_u32(uint32x4_t v, const char *title)
{
  const uint32_t lane0 = vgetq_lane_u32(v, 0);
  const uint32_t lane1 = vgetq_lane_u32(v, 1);
  const uint32_t lane2 = vgetq_lane_u32(v, 2);
  const uint32_t lane3 = vgetq_lane_u32(v, 3);
  printf("%s = %x, %x, %x, %x\n", title, lane0, lane1, lane2, lane3);
}
// Writes the four float lanes of `v` to stdout with %f, formatted as
// "<title> = l0, l1, l2, l3" followed by a newline.
void print_f32(float32x4_t v, const char *title)
{
  float lanes[4];
  vst1q_f32(lanes, v);
  printf("%s = %f, %f, %f, %f\n", title, lanes[0], lanes[1], lanes[2], lanes[3]);
}
// Accurate simulation of _mm_min_ps using ARM NEON.
//
// https://www.felixcloutier.com/x86/minps
//
// MINPS is defined per lane as `a < b ? a : b`, which gives:
//   - both inputs (+/-)0.0            -> the second input,
//   - first input NaN (sNaN or qNaN)  -> the second input,
//   - second input NaN (sNaN or qNaN) -> the second input, bits unchanged,
//   - otherwise                       -> min(a, b).
//
// The special-case masks are computed and printed only as a trace. The
// returned value comes from the compare-and-select at the end, so it does not
// depend on how the NaN classifiers label a lane.
inline float32x4_t vmin(float32x4_t a, float32x4_t b)
{
  // Must be a float vector: it is the right-hand operand of vceqq_f32 below.
  const float32x4_t vzero = vdupq_n_f32(0.0f);

  const uint32x4_t v_src1_is_snan = is_snan(b);

  // fortunately, vceqq_f32 ignores the sign of zero. It already yields a
  // uint32x4_t mask, so no reinterpret is needed.
  const uint32x4_t v_both_are_zeros = vandq_u32(vceqq_f32(a, vzero),
                                                vceqq_f32(b, vzero));

  const uint32x4_t v_src0_is_nan = is_nan(a);

  // Native NEON minimum, shown in the trace for comparison with the result.
  const float32x4_t v_min = vminq_f32(a, b);

  print_u32(v_src0_is_nan, "src0 is NaN");
  print_u32(v_src1_is_snan, "src1 is sNaN");
  print_u32(v_both_are_zeros, "both src0 and src1 are zero");

  const float32x4_t v_after_zero_case = vbslq_f32(v_both_are_zeros, b, v_min);

  print_f32(v_min, "min(a, b)");
  print_f32(v_after_zero_case, "after both zero hadling");

  // Lanes in which the native minimum cannot be trusted to match MINPS.
  const uint32x4_t v_require_special_handling =
      vorrq_u32(v_src1_is_snan, vorrq_u32(v_both_are_zeros, v_src0_is_nan));
  print_u32(v_require_special_handling, "require special handling");

  // a < b is false for equal zeros and whenever either operand is NaN, so the
  // select falls back to b in exactly the cases MINPS returns its second
  // operand. vbslq_f32 copies bits, so a signaling NaN in b is not quieted.
  return vbslq_f32(vcltq_f32(a, b), a, b);
}
int main() | |
{ | |
float32x4_t v0, v1, a; | |
__attribute__((aligned(16))) float in0[4]; | |
__attribute__((aligned(16))) float in1[4]; | |
__attribute__((aligned(16))) float buf[4]; | |
{ | |
buf[0] = std::numeric_limits<float>::quiet_NaN(); | |
buf[1] = std::numeric_limits<float>::signaling_NaN(); | |
printf("qnan, snan = %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&buf[0])), | |
*(reinterpret_cast<uint32_t *>(&buf[1]))); | |
} | |
in0[0] = 1.0f; | |
in0[1] = 2.0f; | |
in0[2] = std::numeric_limits<float>::infinity(); | |
in0[3] = std::numeric_limits<float>::max(); | |
in1[0] = 2.0f; | |
in1[1] = 1.0f; | |
in1[2] = std::numeric_limits<float>::max(); | |
in1[3] = std::numeric_limits<float>::infinity(); | |
v0 = vld1q_f32(in0); | |
v1 = vld1q_f32(in1); | |
a = vmin(v0, v1); | |
vst1q_f32(buf, a); | |
printf("in0 = %f, %f, %f, %f\n", | |
in0[0], | |
in0[1], | |
in0[2], | |
in0[3]); | |
printf("in1 = %f, %f, %f, %f\n", | |
in1[0], | |
in1[1], | |
in1[2], | |
in1[3]); | |
printf("vmin = %f, %f, %f, %f\n", | |
buf[0], | |
buf[1], | |
buf[2], | |
buf[3]); | |
printf("--------------\n"); | |
uint32x4_t m = vceqq_f32(v0, v0); | |
__attribute__((aligned(16))) uint32_t mbuf[4]; | |
printf("mbuf = %d, %d, %d, %d\n", | |
mbuf[0], | |
mbuf[1], | |
mbuf[2], | |
mbuf[3]); | |
a = vmaxnmq_f32(vminq_f32(v1, v0), v1); | |
vst1q_f32(buf, a); | |
printf("in0 = %f, %f, %f, %f\n", | |
in0[0], | |
in0[1], | |
in0[2], | |
in0[3]); | |
printf("in1 = %f, %f, %f, %f\n", | |
in1[0], | |
in1[1], | |
in1[2], | |
in1[3]); | |
printf("vmin = %f, %f, %f, %f\n", | |
buf[0], | |
buf[1], | |
buf[2], | |
buf[3]); | |
printf("hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&buf[0])), | |
*(reinterpret_cast<uint32_t *>(&buf[1])), | |
*(reinterpret_cast<uint32_t *>(&buf[2])), | |
*(reinterpret_cast<uint32_t *>(&buf[3]))); | |
printf("--------------\n"); | |
in0[0] = std::numeric_limits<float>::quiet_NaN(); | |
in0[1] = std::numeric_limits<float>::signaling_NaN(); | |
in0[2] = 1.0f; | |
in0[3] = 2.0f; | |
in1[0] = 1.0f; | |
in1[1] = 2.0f; | |
in1[2] = std::numeric_limits<float>::quiet_NaN(); | |
in1[3] = std::numeric_limits<float>::signaling_NaN(); | |
v0 = vld1q_f32(in0); | |
v1 = vld1q_f32(in1); | |
a = vmin(v0, v1); | |
vst1q_f32(buf, a); | |
printf("in0 = %f, %f, %f, %f\n", | |
in0[0], | |
in0[1], | |
in0[2], | |
in0[3]); | |
printf("in1 = %f, %f, %f, %f\n", | |
in1[0], | |
in1[1], | |
in1[2], | |
in1[3]); | |
printf("vmin = %f, %f, %f, %f\n", | |
buf[0], | |
buf[1], | |
buf[2], | |
buf[3]); | |
printf("hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&buf[0])), | |
*(reinterpret_cast<uint32_t *>(&buf[1])), | |
*(reinterpret_cast<uint32_t *>(&buf[2])), | |
*(reinterpret_cast<uint32_t *>(&buf[3]))); | |
assert(std::fabs(buf[0] - 1.0f) < std::numeric_limits<float>::epsilon()); // 1.0 | |
assert(std::fabs(buf[1] - 2.0f) < std::numeric_limits<float>::epsilon()); // 2.0 | |
assert(check_qnan(buf[2])); // qNaN | |
assert(check_qnan(buf[3])); // sNaN | |
printf("--------------\n"); | |
in0[0] = -0.0f; | |
in0[1] = 0.0f; | |
in0[2] = std::numeric_limits<float>::quiet_NaN(); | |
in0[3] = std::numeric_limits<float>::signaling_NaN(); | |
in1[0] = 0.0f; | |
in1[1] = -0.0f; | |
in1[2] = std::numeric_limits<float>::signaling_NaN(); | |
in1[3] = std::numeric_limits<float>::quiet_NaN(); | |
v0 = vld1q_f32(in0); | |
v1 = vld1q_f32(in1); | |
a = vmin(v0, v1); | |
vst1q_f32(buf, a); | |
printf("in0 = %f, %f, %f, %f\n", | |
in0[0], | |
in0[1], | |
in0[2], | |
in0[3]); | |
printf("in0 hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&in0[0])), | |
*(reinterpret_cast<uint32_t *>(&in0[1])), | |
*(reinterpret_cast<uint32_t *>(&in0[2])), | |
*(reinterpret_cast<uint32_t *>(&in0[3]))); | |
printf("in1 = %f, %f, %f, %f\n", | |
in1[0], | |
in1[1], | |
in1[2], | |
in1[3]); | |
printf("in1 hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&in1[0])), | |
*(reinterpret_cast<uint32_t *>(&in1[1])), | |
*(reinterpret_cast<uint32_t *>(&in1[2])), | |
*(reinterpret_cast<uint32_t *>(&in1[3]))); | |
printf("vmin = %f, %f, %f, %f\n", | |
buf[0], | |
buf[1], | |
buf[2], | |
buf[3]); | |
printf("hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&buf[0])), | |
*(reinterpret_cast<uint32_t *>(&buf[1])), | |
*(reinterpret_cast<uint32_t *>(&buf[2])), | |
*(reinterpret_cast<uint32_t *>(&buf[3]))); | |
assert(*(reinterpret_cast<uint32_t *>(&buf[0])) == 0x00000000); // 0.0 | |
assert(*(reinterpret_cast<uint32_t *>(&buf[1])) == 0x80000000); // -0.0 | |
assert(check_snan(buf[2])); // sNaN | |
assert(check_qnan(buf[3])); // qNan | |
return 0; | |
} | |
/* ------------- sse2 -----------------------------------------------*/ | |
#include <xmmintrin.h> | |
#include <cstdio> | |
#include <limits> | |
#include <cstdint> | |
#include <cmath> | |
#include <cassert> | |
// Returns true when `f` is a signaling NaN: std::isnan() holds and bit 22 of
// the binary32 encoding (the quiet bit) is clear. The raw encoding and both
// intermediate flags are printed to stdout as a trace.
bool check_snan(float f)
{
  static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits wide");

  const bool is_nan = std::isnan(f);

  // Copy the encoding out with memcpy; reading a float object through a
  // uint32_t pointer violates the strict-aliasing rule.
  uint32_t val = 0;
  std::memcpy(&val, &f, sizeof(val));

  const bool bit_qnan = (val & 0x00400000u) != 0; // qNaN bit
  printf("val = %x, is_nan = %d, bit_qnan = %d\n", val, is_nan, bit_qnan);
  return is_nan && !bit_qnan;
}
// Returns true when `f` is a quiet NaN, i.e. all eight exponent bits and the
// most significant fraction bit (bit 22) of its binary32 encoding are set.
// The sign bit is ignored. Ordinary numbers, infinities and signaling NaNs
// all yield false.
bool check_qnan(float f)
{
  static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits wide");

  // memcpy instead of a pointer cast keeps the bit read free of aliasing UB.
  uint32_t val = 0;
  std::memcpy(&val, &f, sizeof(val));

  // Every bit of the mask has to be set. Testing `val & mask` for non-zero
  // would accept any value with a non-zero exponent, such as 1.0f.
  constexpr uint32_t kExpAndQuietBit = 0x7fc00000u; // exp + qNaN bit
  return (val & kExpAndQuietBit) == kExpAndQuietBit;
}
int main() | |
{ | |
__m128 v0; | |
__m128 v1; | |
__m128 a; | |
__attribute__((aligned(16))) float in0[4]; | |
__attribute__((aligned(16))) float in1[4]; | |
__attribute__((aligned(16))) float buf[4]; | |
{ | |
buf[0] = std::numeric_limits<float>::quiet_NaN(); | |
buf[1] = std::numeric_limits<float>::signaling_NaN(); | |
printf("qnan, snan = %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&buf[0])), | |
*(reinterpret_cast<uint32_t *>(&buf[1]))); | |
} | |
in0[0] = 1.0f; | |
in0[1] = 2.0f; | |
in0[2] = std::numeric_limits<float>::infinity(); | |
in0[3] = std::numeric_limits<float>::max(); | |
in1[0] = 2.0f; | |
in1[1] = 1.0f; | |
in1[2] = std::numeric_limits<float>::max(); | |
in1[3] = std::numeric_limits<float>::infinity(); | |
v0 = _mm_load_ps(in0); | |
v1 = _mm_load_ps(in1); | |
a = _mm_min_ps(v0, v1); | |
_mm_store_ps(buf, a); | |
printf("in0 = %f, %f, %f, %f\n", | |
in0[0], | |
in0[1], | |
in0[2], | |
in0[3]); | |
printf("in1 = %f, %f, %f, %f\n", | |
in1[0], | |
in1[1], | |
in1[2], | |
in1[3]); | |
printf("vmin = %f, %f, %f, %f\n", | |
buf[0], | |
buf[1], | |
buf[2], | |
buf[3]); | |
printf("--------------\n"); | |
in0[0] = std::numeric_limits<float>::quiet_NaN(); | |
in0[1] = std::numeric_limits<float>::signaling_NaN(); | |
in0[2] = 1.0f; | |
in0[3] = 2.0f; | |
in1[0] = 1.0f; | |
in1[1] = 2.0f; | |
in1[2] = std::numeric_limits<float>::quiet_NaN(); | |
in1[3] = std::numeric_limits<float>::signaling_NaN(); | |
v0 = _mm_load_ps(in0); | |
v1 = _mm_load_ps(in1); | |
a = _mm_min_ps(v0, v1); | |
_mm_store_ps(buf, a); | |
printf("in0 = %f, %f, %f, %f\n", | |
in0[0], | |
in0[1], | |
in0[2], | |
in0[3]); | |
printf("in1 = %f, %f, %f, %f\n", | |
in1[0], | |
in1[1], | |
in1[2], | |
in1[3]); | |
printf("vmin = %f, %f, %f, %f\n", | |
buf[0], | |
buf[1], | |
buf[2], | |
buf[3]); | |
printf("hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&buf[0])), | |
*(reinterpret_cast<uint32_t *>(&buf[1])), | |
*(reinterpret_cast<uint32_t *>(&buf[2])), | |
*(reinterpret_cast<uint32_t *>(&buf[3]))); | |
assert(std::fabs(buf[0] - 1.0f) < std::numeric_limits<float>::epsilon()); // 1.0 | |
assert(std::fabs(buf[1] - 2.0f) < std::numeric_limits<float>::epsilon()); // 2.0 | |
assert(check_qnan(buf[2])); // qNaN | |
assert(check_qnan(buf[3])); // sNaN | |
printf("----------------\n"); | |
in0[0] = -0.0f; | |
in0[1] = 0.0f; | |
in0[2] = std::numeric_limits<float>::quiet_NaN(); | |
in0[3] = std::numeric_limits<float>::signaling_NaN(); | |
in1[0] = 0.0f; | |
in1[1] = -0.0f; | |
in1[2] = std::numeric_limits<float>::signaling_NaN(); | |
in1[3] = std::numeric_limits<float>::quiet_NaN(); | |
v0 = _mm_load_ps(in0); | |
v1 = _mm_load_ps(in1); | |
a = _mm_min_ps(v0, v1); | |
_mm_store_ps(buf, a); | |
printf("in0 = %f, %f, %f, %f\n", | |
in0[0], | |
in0[1], | |
in0[2], | |
in0[3]); | |
printf("in0 hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&in0[0])), | |
*(reinterpret_cast<uint32_t *>(&in0[1])), | |
*(reinterpret_cast<uint32_t *>(&in0[2])), | |
*(reinterpret_cast<uint32_t *>(&in0[3]))); | |
printf("in1 = %f, %f, %f, %f\n", | |
in1[0], | |
in1[1], | |
in1[2], | |
in1[3]); | |
printf("in1 hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&in1[0])), | |
*(reinterpret_cast<uint32_t *>(&in1[1])), | |
*(reinterpret_cast<uint32_t *>(&in1[2])), | |
*(reinterpret_cast<uint32_t *>(&in1[3]))); | |
printf("vmin = %f, %f, %f, %f\n", | |
buf[0], | |
buf[1], | |
buf[2], | |
buf[3]); | |
printf("hex = %x, %x, %x, %x\n", | |
*(reinterpret_cast<uint32_t *>(&buf[0])), | |
*(reinterpret_cast<uint32_t *>(&buf[1])), | |
*(reinterpret_cast<uint32_t *>(&buf[2])), | |
*(reinterpret_cast<uint32_t *>(&buf[3]))); | |
assert(*(reinterpret_cast<uint32_t *>(&buf[0])) == 0x00000000); // 0.0 | |
assert(*(reinterpret_cast<uint32_t *>(&buf[1])) == 0x80000000); // -0.0 | |
assert(check_snan(buf[2])); // sNaN | |
assert(check_qnan(buf[3])); // qNan | |
printf("----------------\n"); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment