Skip to content

Instantly share code, notes, and snippets.

@guilhermeleobas
Last active June 2, 2021 17:48
Show Gist options
  • Save guilhermeleobas/52ac0e58dbdea324cf3c6861d276abb1 to your computer and use it in GitHub Desktop.
Save guilhermeleobas/52ac0e58dbdea324cf3c6861d276abb1 to your computer and use it in GitHub Desktop.
from numba import njit
from numba.core import types, errors
from numba.core.types import Array, float64, boolean, Omitted, int64
from numba.np.numpy_support import type_can_asarray
from numba.core.extending import overload, register_jitable
import numpy as np
import pytest
import os
# Choose, once at import time, which ``np.isclose`` overload to register with
# numba: ``IMPL=guilherme`` selects the vectorised implementation, any other
# value selects the element-by-element (Jim Pivarski) implementation.
# This is an explicit check rather than ``assert`` so the guard still runs
# under ``python -O`` (where asserts are stripped and the lookup below would
# fail with a bare KeyError).
if 'IMPL' not in os.environ:
    raise RuntimeError(
        "the IMPL environment variable must be set "
        "('guilherme' selects Guilherme's implementation, "
        "any other value selects Jim Pivarski's)"
    )
if os.environ['IMPL'] == 'guilherme':
    print('Using Guilherme implementation')

    @register_jitable
    def _within_tol(a, b, rtol, atol):
        """Element-wise tolerance test ``|a - b| <= atol + rtol * |b|``.

        The result is only meaningful where both operands are finite: for
        example ``_within_tol(1.0, inf, ...)`` is True because ``inf <= inf``.
        The caller therefore masks out non-finite positions.
        """
        return np.less_equal(np.abs(a - b), atol + rtol * np.abs(b))

    @overload(np.isclose)
    def np_isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        """numba overload of ``np.isclose`` following NumPy's algorithm.

        Raises ``errors.TypingError`` while numba types the call if either
        input cannot be passed to ``np.asarray``.
        """
        # Based on NumPy impl.
        # https://github.com/numpy/numpy/blob/d9b1e32cb8ef90d6b4a47853241db2a28146a57d/numpy/core/numeric.py#L2180-L2292
        if not (type_can_asarray(a) and type_can_asarray(b)):
            raise errors.TypingError("Inputs for `np.isclose` must be array-like.")

        def impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
            xfin = np.asarray(np.isfinite(a))
            yfin = np.asarray(np.isfinite(b))
            x, y = np.asarray(a), np.asarray(b)
            if np.all(xfin) and np.all(yfin):
                # Fast path: no NaN or infinity anywhere, so the tolerance
                # test alone decides every element.
                return _within_tol(x, y, rtol, atol)
            else:
                finite = xfin & yfin
                # Multiplying by a ones array of the broadcast shape expands
                # x and y to that common shape before the in-place updates.
                ones = np.ones_like(finite)
                x = x * ones
                y = y * ones
                r = _within_tol(x, y, rtol, atol)
                # Clear every position where either operand is NaN or
                # infinite; _within_tol's answer is not meaningful there.
                r &= finite
                # An infinity is close only to an infinity of the same sign.
                r |= (x == np.asarray(np.inf)) & (y == np.asarray(np.inf))
                r |= (x == np.asarray(-np.inf)) & (y == np.asarray(-np.inf))
                if equal_nan:
                    xnan = np.asarray(np.isnan(a))
                    ynan = np.asarray(np.isnan(b))
                    return r | (xnan & ynan)
                else:
                    return r

        return impl
else:  # Jim Pivarski implementation
    print('Using Jim Pivarski implementation!')

    @register_jitable
    def _isclose_item(x, y, rtol, atol, equal_nan):
        """Decide closeness for a single pair of scalar values.

        Two NaNs are close only when ``equal_nan`` is true, and a NaN paired
        with anything else is not close. Two infinities are close only if
        they have the same sign, and an infinity is never close to a finite
        value. Otherwise the result is ``|x - y| <= atol + rtol * |y|``.
        """
        if np.isnan(x) and np.isnan(y):
            return equal_nan
        elif np.isinf(x) and np.isinf(y):
            return (x > 0) == (y > 0)
        elif np.isinf(x) or np.isinf(y):
            return False
        else:
            return abs(x - y) <= atol + rtol * abs(y)

    @overload(np.isclose)
    def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        """numba overload of ``np.isclose`` that loops over elements.

        The implementation is chosen while numba types the call:

        * at least one array with ``ndim > 0``: broadcast and loop, returning
          a boolean array of the broadcast shape;
        * otherwise, at least one 0-d array: compare the single values and
          return a 0-d boolean array;
        * otherwise (scalars only): return a scalar boolean.
        """
        if (isinstance(a, types.Array) and a.ndim > 0) or (
            isinstance(b, types.Array) and b.ndim > 0
        ):
            def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
                xa = np.asarray(a)
                ya = np.asarray(b)
                # Broadcast both operands to their common shape by adding a
                # zero array shaped like the *other* operand. Flattening each
                # input separately would index past the end of the smaller
                # one. The sums are fresh contiguous arrays, so reshape(-1)
                # is always valid; incompatible shapes are rejected by numba's
                # array arithmetic.
                x = xa + np.zeros_like(ya)
                y = ya + np.zeros_like(xa)
                xs = x.reshape(-1)
                ys = y.reshape(-1)
                out = np.zeros(xs.size, np.bool_)
                for i in range(xs.size):
                    out[i] = _isclose_item(xs[i], ys[i], rtol, atol, equal_nan)
                return out.reshape(x.shape)
        elif isinstance(a, types.Array) or isinstance(b, types.Array):
            def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
                # Only 0-d arrays and scalars reach this branch. np.asarray
                # makes .item() valid whichever of the two each argument is.
                return np.asarray(
                    _isclose_item(
                        np.asarray(a).item(), np.asarray(b).item(),
                        rtol, atol, equal_nan,
                    )
                )
        else:
            def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
                return _isclose_item(a, b, rtol, atol, equal_nan)
        return isclose_impl
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Forward every argument to ``np.isclose``.

    Run as plain Python this is NumPy's own ``isclose``. When compiled with
    ``njit``, numba resolves the ``np.isclose`` call to the overload
    registered in this module.
    """
    return np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def values():
    """Generate ``(a, b, kwargs)`` triples used to exercise ``isclose``.

    ``a`` and ``b`` are Python scalars, lists or NumPy arrays (the callers
    convert them with ``np.asarray``). ``kwargs`` holds the keyword options
    (``rtol``, ``atol``, ``equal_nan``) for that particular call.
    """
    r_tol = 1e-5
    a_tol = 1e-8
    pair = np.array([100, 1000])
    cube = np.arange(8).reshape((2, 2, 2))
    tols = {'rtol': r_tol, 'atol': a_tol}

    # Hand-written cases; each one carries its own keyword options.
    yield from [
        (1e10, 1.00001e10, {}),
        (1e10, np.nan, {}),
        ([1e-8, 1e-7], [0.0, 0.0], {}),
        ([1e10, 1e-7], [1.00001e10, 1e-8], {}),
        ([1e10, 1e-8], [1.00001e10, 1e-9], {}),
        ([1e10, 1e-8], [1.0001e10, 1e-9], {}),
        ([1.0, np.nan], [1.0, np.nan], {}),
        ([1.0, np.nan], [1.0, np.nan], {'equal_nan': True}),
        ([np.nan, np.nan], [1.0, np.nan], {'equal_nan': True}),
        ([1e-100, 1e-7], [0.0, 0.0], {'atol': 0.0}),
        ([1e-10, 1e-10], [1e-20, 0.0], {}),
        ([1e-10, 1e-10], [1e-20, 0.999999e-10], {'atol': 0.0}),
        ([1, np.inf, 2], [3, np.inf, 4], tols),
    ]

    # The remaining pairs all share the ``tols`` options and are adapted from
    # NumPy's own isclose tests:
    # https://github.com/numpy/numpy/blob/aac965af6032b69d5cb515ad785cc9a331e816f4/numpy/core/tests/test_numeric.py#L2298-L2335 # noqa: E501

    # Pairs meant to be close everywhere, except the two marked NOTE.
    mostly_close = [
        ([0, 1], [1, 0]),  # NOTE: 0 vs 1 and 1 vs 0 are not close
        (pair, pair),
        ([1], [1 + r_tol + a_tol]),
        (pair, pair + pair * r_tol),
        (pair, pair + pair * r_tol + a_tol),
        (cube, cube + cube * r_tol),
        (np.inf, np.inf),
        (-np.inf, np.inf),  # NOTE: opposite-sign infinities are not close
        (np.inf, [np.inf]),
        ([np.inf, -np.inf], [np.inf, -np.inf]),
    ]
    # Pairs with no close elements.
    not_close = [
        ([np.inf, 0], [1, np.inf]),
        ([np.inf, -np.inf], [1, 0]),
        ([np.inf, np.inf], [1, -np.inf]),
        ([np.inf, np.inf], [1, 0]),
        ([np.nan, 0], [np.nan, -np.inf]),
        ([a_tol * 2], [0]),
        ([1], [1 + r_tol + a_tol * 2]),
        (cube, cube + r_tol * 1.1 * cube + a_tol * 1.1),
        (np.array([np.inf, 1]), np.array([0, np.inf])),
    ]
    # Pairs where only some elements are close (several rely on broadcasting).
    partly_close = [
        ([np.inf, 0], [np.inf, a_tol * 2]),
        ([a_tol, 1, 1e6 * (1 + 2 * r_tol) + a_tol], [0, np.nan, 1e6]),
        (np.arange(3), [0, 1, 2.1]),
        (np.nan, [np.nan, np.nan, np.nan]),
        ([0], [a_tol, np.inf, -np.inf, np.nan]),
        (0, [a_tol, np.inf, -np.inf, np.nan]),
    ]
    for lhs, rhs in mostly_close + not_close + partly_close:
        yield lhs, rhs, tols
# ``pyfunc`` runs the wrapper as plain Python (so it calls NumPy's
# ``np.isclose``); ``cfunc`` is the same wrapper compiled by numba, which
# resolves ``np.isclose`` to the overload registered in this module.
pyfunc = isclose
cfunc = njit(isclose)

# Warm-up at import time: call the compiled function once per test case so
# every argument-type signature has been compiled before the tests execute.
for a, b, kwargs in values():
    a, b = np.asarray(a), np.asarray(b)
    cfunc(a, b, **kwargs)
@pytest.mark.parametrize('a, b, kwargs', values())
def test_foo(a, b, kwargs):
    """Check the numba-compiled ``isclose`` against NumPy for one case.

    ``pyfunc`` is the same wrapper run as plain Python, i.e. NumPy's own
    ``np.isclose``, and serves as the reference. Merely calling ``cfunc``
    would only detect crashes. Comparing against ``pyfunc`` also catches
    wrong answers and wrongly shaped (mis-broadcast) results.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    got = cfunc(a, b, **kwargs)
    expected = pyfunc(a, b, **kwargs)
    # assert_array_equal treats a 0-d/scalar operand as broadcastable, so the
    # output shape is checked explicitly as well.
    assert np.shape(got) == np.shape(expected)
    np.testing.assert_array_equal(got, expected)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment