Last active
March 29, 2021 11:17
-
-
Save hugohadfield/bff7c95fefb0937470a27c99d9b084da to your computer and use it in GitHub Desktop.
A test of double-double arithmetic implemented as a custom NumPy record dtype, with arithmetic operators overloaded in Numba
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator | |
import pytest | |
import numba | |
from numba import types | |
import numpy as np | |
##### CUSTOM NUMPY DTYPE: ONE DOUBLE-DOUBLE STORED AS TWO FLOAT64 FIELDS #####
# Field 'x' is passed to the kernels as the high word, field 'y' as the low word.
np_type = np.dtype([('x', np.float64), ('y', np.float64)])
# A zero-valued scalar record of the custom dtype.
zero = np.zeros(1, dtype=np_type)[0]
# Numba's Record type for np_type ...
numba_dtype = numba.from_dtype(np_type)
# ... and the 1-D (any layout) Numba array type whose elements are that record.
numba_array_type = numba_dtype[:]
##### DOUBLE-DOUBLE ARITHMETIC KERNELS USED BY THE CUSTOM TYPE #####
@numba.njit
def _two_sum_quick(x, y):
    """Fast two-sum: return ``(total, err)`` where ``total = fl(x + y)``.

    ``err`` estimates the rounding error of the addition. The standard
    precondition for this fast variant to be error-free is ``|x| >= |y|``;
    it is not checked here. Callers use it to renormalise a (big, small) pair.
    """
    total = x + y
    # (total - x) is the portion of y that landed in total; what is left over
    # from y is the rounding error.
    err = y - (total - x)
    return total, err
@numba.njit
def _two_sum(x, y):
    """Knuth two-sum: return ``(total, err)`` with ``total = fl(x + y)``.

    Unlike the fast variant, this places no ordering requirement on the
    magnitudes of ``x`` and ``y``.
    """
    total = x + y
    # Recover the parts of y and x that were actually absorbed into total.
    y_virtual = total - x
    x_virtual = total - y_virtual
    # Sum of what each operand lost is the rounding error.
    err = (x - x_virtual) + (y - y_virtual)
    return total, err
@numba.njit
def _two_difference(x, y):
    """Two-difference: return ``(diff, err)`` with ``diff = fl(x - y)``.

    This is the subtraction counterpart of ``_two_sum``, with the same
    sequence of operations adapted for ``x - y``.
    """
    diff = x - y
    # Portion of -y absorbed into diff, then the portion of x absorbed.
    neg_y_virtual = diff - x
    x_virtual = diff - neg_y_virtual
    err = (x - x_virtual) - (y + neg_y_virtual)
    return diff, err
@numba.njit
def _two_product(x, y):
    """Dekker two-product: return ``(prod, err)`` with ``prod = fl(x * y)``.

    ``err`` is computed from partial products of the high/low halves of
    ``x`` and ``y`` without an FMA instruction.
    """
    splitter = 134217729.0  # 2**27 + 1, the Veltkamp split constant for float64
    # Split x into a high half and the remainder.
    scaled_x = x * splitter
    x_hi = scaled_x - (scaled_x - x)
    x_lo = x - x_hi
    # Same split for y.
    scaled_y = y * splitter
    y_hi = scaled_y - (scaled_y - y)
    y_lo = y - y_hi
    prod = x * y
    # Accumulate the half-products against prod, strictly left to right.
    err = (((x_hi * y_hi - prod) + x_hi * y_lo) + x_lo * y_hi) + x_lo * y_lo
    return prod, err
@numba.njit
def mul_double_double(ax, bx, ay, by):
    """Multiply double-doubles ``a = (ax, ay)`` and ``b = (bx, by)``.

    Argument order is both high words first, then both low words. The
    ``ay * by`` term is not computed. The result is a renormalised
    ``(high, low)`` tuple.
    """
    hi, lo = _two_product(ax, bx)
    # Fold the cross terms into the error word (same evaluation order as
    # ``lo + ax*by + ay*bx``).
    cross = (lo + ax * by) + ay * bx
    return _two_sum_quick(hi, cross)
@numba.njit
def rmul_double_double(ax, ay, other):
    """Multiply the double-double ``(ax, ay)`` by the plain number ``other``.

    Returns a renormalised ``(high, low)`` tuple.
    """
    hi, lo = _two_product(other, ax)
    return _two_sum_quick(hi, lo + other * ay)
@numba.njit
def add_double_double(ax, bx, ay, by):
    """Add double-doubles ``a = (ax, ay)`` and ``b = (bx, by)``.

    Argument order is both high words first, then both low words. The high
    words are added exactly with ``_two_sum``, the low words are added into
    its error, and the pair is renormalised to a ``(high, low)`` tuple.
    """
    hi, lo = _two_sum(ax, bx)
    return _two_sum_quick(hi, (lo + ay) + by)
@numba.njit
def radd_double_double(ax, ay, other):
    """Add the plain number ``other`` to the double-double ``(ax, ay)``.

    Returns a renormalised ``(high, low)`` tuple.
    """
    hi, lo = _two_sum(other, ax)
    return _two_sum_quick(hi, lo + ay)
@numba.njit
def numpy_rmul_double_double(a, b):
    """Return a new double-double record equal to record ``a`` times number ``b``."""
    scaled = np.zeros(1, dtype=numba_dtype)[0]
    scaled['x'], scaled['y'] = rmul_double_double(a['x'], a['y'], b)
    return scaled
@numba.njit
def numpy_mul_double_double(a, b):
    """Return a new double-double record equal to record ``a`` times record ``b``."""
    product = np.zeros(1, dtype=numba_dtype)[0]
    product['x'], product['y'] = mul_double_double(a['x'], b['x'], a['y'], b['y'])
    return product
@numba.njit
def numpy_add_double_double(a, b):
    """Return a new double-double record equal to record ``a`` plus record ``b``."""
    total = np.zeros(1, dtype=numba_dtype)[0]
    total['x'], total['y'] = add_double_double(a['x'], b['x'], a['y'], b['y'])
    return total
@numba.njit
def numpy_radd_double_double(a, b):
    """Return a new double-double record equal to record ``a`` plus number ``b``."""
    shifted = np.zeros(1, dtype=numba_dtype)[0]
    shifted['x'], shifted['y'] = radd_double_double(a['x'], a['y'], b)
    return shifted
# Zero-filled allocation of the custom dtype, callable from nopython-mode code.
@numba.njit
def numpy_zeros_array_doubledouble(l):
    """Return a new array of shape ``l`` whose double-double records are all zero.

    The parameter keeps its original name ``l`` for caller compatibility.
    """
    zeroed = np.zeros(l, dtype=numba_dtype)
    return zeroed
##### THIS IS WHERE WE DEFINE THE CUSTOM OVERLOADS #####
@numba.extending.overload(operator.mul)
def np_double_double_mul(a, b):
    """Numba ``overload`` of ``operator.mul`` for double-double values.

    ``a`` and ``b`` are Numba *types* at this point, not values. An
    implementation is returned for:

    * record * record, record * number, number * record;
    * 1-D record array * number and number * 1-D record array;
    * 1-D record array * 1-D record array, element-wise. The two arrays must
      have the same length, otherwise ``ValueError`` is raised at run time.

    Any other combination returns ``None``, so Numba keeps looking for
    another implementation of ``*``.
    """
    array_class = type(numba_array_type)
    number_class = types.abstract.Number

    def is_dd_vector(ty):
        # The array implementations index with a single integer, so they only
        # handle 1-D arrays of the double-double record.
        return (isinstance(ty, array_class)
                and ty.dtype == numba_dtype
                and ty.ndim == 1)

    # Scalar (record) operands.
    if a == numba_dtype and b == numba_dtype:
        def impl(a, b):
            return numpy_mul_double_double(a, b)
        return impl
    if a == numba_dtype and isinstance(b, number_class):
        def impl(a, b):
            return numpy_rmul_double_double(a, b)
        return impl
    if b == numba_dtype and isinstance(a, number_class):
        def impl(a, b):
            return numpy_rmul_double_double(b, a)
        return impl

    # 1-D array operands.
    if is_dd_vector(a) and isinstance(b, number_class):
        def impl(a, b):
            n = a.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_rmul_double_double(a[i], b)
            return output
        return impl
    if is_dd_vector(a) and is_dd_vector(b):
        def impl(a, b):
            n = a.shape[0]
            # Numba does not bounds-check by default: without this check a
            # shorter `b` would be read past its end, and a longer one would
            # be silently truncated.
            if b.shape[0] != n:
                raise ValueError("double-double arrays must have equal length")
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_mul_double_double(a[i], b[i])
            return output
        return impl
    if isinstance(a, number_class) and is_dd_vector(b):
        def impl(a, b):
            n = b.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_rmul_double_double(b[i], a)
            return output
        return impl
    return None
@numba.extending.overload(operator.add)
def np_double_double_add(a, b):
    """Numba ``overload`` of ``operator.add`` for double-double values.

    ``a`` and ``b`` are Numba *types* at this point, not values. An
    implementation is returned for:

    * record + record, record + number, number + record;
    * 1-D record array + number and number + 1-D record array;
    * 1-D record array + 1-D record array, element-wise. The two arrays must
      have the same length, otherwise ``ValueError`` is raised at run time.

    Any other combination returns ``None``, so Numba keeps looking for
    another implementation of ``+``.
    """
    array_class = type(numba_array_type)
    number_class = types.abstract.Number

    def is_dd_vector(ty):
        # The array implementations index with a single integer, so they only
        # handle 1-D arrays of the double-double record.
        return (isinstance(ty, array_class)
                and ty.dtype == numba_dtype
                and ty.ndim == 1)

    # Scalar (record) operands.
    if a == numba_dtype and b == numba_dtype:
        def impl(a, b):
            return numpy_add_double_double(a, b)
        return impl
    if a == numba_dtype and isinstance(b, number_class):
        def impl(a, b):
            return numpy_radd_double_double(a, b)
        return impl
    if b == numba_dtype and isinstance(a, number_class):
        def impl(a, b):
            return numpy_radd_double_double(b, a)
        return impl

    # 1-D array operands.
    if is_dd_vector(a) and isinstance(b, number_class):
        def impl(a, b):
            n = a.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_radd_double_double(a[i], b)
            return output
        return impl
    if is_dd_vector(a) and is_dd_vector(b):
        def impl(a, b):
            n = a.shape[0]
            # Numba does not bounds-check by default: without this check a
            # shorter `b` would be read past its end, and a longer one would
            # be silently truncated.
            if b.shape[0] != n:
                raise ValueError("double-double arrays must have equal length")
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_add_double_double(a[i], b[i])
            return output
        return impl
    if isinstance(a, number_class) and is_dd_vector(b):
        def impl(a, b):
            n = b.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_radd_double_double(b[i], a)
            return output
        return impl
    return None
##### TEST UTILITIES #####
@pytest.fixture
def rng():
    """Provide a NumPy Generator with a fixed seed (1).

    The fixed seed makes the pseudo-random tests reproducible from run to run.
    """
    return np.random.default_rng(seed=1)
def gen_a_b(rng):
    """Return a pair ``(a, b)`` of double-double scalar records.

    Each record's high word ``'x'`` is a standard-normal draw from *rng*, with
    ``a`` drawn before ``b``. Each low word ``'y'`` is 0.0.
    """
    def draw_record():
        record = np.zeros(1, dtype=np_type)[0]
        record['x'] = rng.standard_normal()
        record['y'] = 0.0
        return record

    first = draw_record()
    second = draw_record()
    return (first, second)
##### THESE ARE THE TESTS #####
class TestDoubleDoubleNumpy:
    """Check the overloaded operators against direct use of the kernels."""

    def test_mul(self, rng):
        @numba.njit
        def mul_test(a, b):
            return 3.0*a*b*2.0

        for _ in range(1000):
            a, b = gen_a_b(rng)
            r, e = mul_double_double(a['x'], b['x'], a['y'], b['y'])
            r, e = rmul_double_double(r, e, 2.0)
            r, e = rmul_double_double(r, e, 3.0)
            res = mul_test(a, b)
            np.testing.assert_allclose((res['x'], res['y']), (r, e))

    def test_array_mul(self, rng):
        # Helper names differ from the test method so they do not shadow it.
        @numba.njit
        def array_mul(c, d):
            e = 3.0*c
            f = d*2.0
            return e*f

        @numba.njit
        def elementwise_mul(c, d):
            e = (3.0*c[0], 3.0*c[1])
            f = (d[0]*2.0, d[1]*2.0)
            return (e[0]*f[0], e[1]*f[1])

        for _ in range(1000):
            a, b = gen_a_b(rng)
            c = np.array([a, b])
            d = np.array([b, b])
            res1 = array_mul(c, d)
            res2 = elementwise_mul(c, d)
            # Check every element, not just the first one.
            for k in range(len(res2)):
                np.testing.assert_allclose(
                    (res1[k]['x'], res1[k]['y']),
                    (res2[k]['x'], res2[k]['y']),
                )

    def test_array_add(self, rng):
        @numba.njit
        def array_add(c, d):
            return 2.0 + c + d + 5

        @numba.njit
        def elementwise_add(c, d):
            return (2.0 + c[0] + d[0] + 5, 2.0 + c[1] + d[1] + 5)

        for _ in range(1000):
            a, b = gen_a_b(rng)
            c = np.array([a, b])
            d = np.array([b, b])
            res1 = array_add(c, d)
            res2 = elementwise_add(c, d)
            # Check every element, not just the first one.
            for k in range(len(res2)):
                np.testing.assert_allclose(
                    (res1[k]['x'], res1[k]['y']),
                    (res2[k]['x'], res2[k]['y']),
                )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment