Last active
June 2, 2021 17:48
-
-
Save guilhermeleobas/52ac0e58dbdea324cf3c6861d276abb1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import njit | |
from numba.core import types, errors | |
from numba.core.types import Array, float64, boolean, Omitted, int64 | |
from numba.np.numpy_support import type_can_asarray | |
from numba.core.extending import overload, register_jitable | |
import numpy as np | |
import pytest | |
import os | |
# Register one of two competing ``np.isclose`` overloads with Numba.  The
# IMPL environment variable selects which: 'guilherme' picks the first
# implementation, any other value picks Jim Pivarski's.
_impl_name = os.environ.get('IMPL')
if _impl_name is None:
    # Explicit error instead of ``assert`` so the check survives ``python -O``.
    raise RuntimeError(
        "Set the IMPL environment variable ('guilherme' selects Guilherme's "
        "implementation; any other value selects Jim Pivarski's).")

if _impl_name == 'guilherme':
    print('Using Guilherme implementation')

    @register_jitable
    def _within_tol(a, b, rtol, atol):
        """Elementwise ``|a - b| <= atol + rtol * |b|`` tolerance test."""
        return np.less_equal(np.abs(a - b), atol + rtol * np.abs(b))

    @overload(np.isclose)
    def np_isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        """Numba overload of ``np.isclose`` for array-like inputs.

        Raises ``TypingError`` at compile time when either input cannot be
        converted with ``np.asarray``.
        """
        # Based on NumPy impl.
        # https://github.com/numpy/numpy/blob/d9b1e32cb8ef90d6b4a47853241db2a28146a57d/numpy/core/numeric.py#L2180-L2292
        if not (type_can_asarray(a) and type_can_asarray(b)):
            raise errors.TypingError("Inputs for `np.isclose` must be array-like.")

        def impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
            # Convert first so the ufuncs below always receive arrays, even
            # for the scalar/sequence inputs accepted by type_can_asarray.
            x, y = np.asarray(a), np.asarray(b)
            xfin = np.asarray(np.isfinite(x))
            yfin = np.asarray(np.isfinite(y))
            if np.all(xfin) and np.all(yfin):
                return _within_tol(x, y, rtol, atol)

            finite = xfin & yfin
            # Multiplying by a ones-array of the broadcast shape broadcasts
            # both operands to that shape without changing their values.
            ones = np.ones_like(finite)
            xb = x * ones
            yb = y * ones
            r = _within_tol(xb, yb, rtol, atol)
            # Drop results computed from non-finite operands ...
            r &= finite
            # ... then mark infinities equal only to same-signed infinities.
            r |= (xb == np.asarray(np.inf)) & (yb == np.asarray(np.inf))
            r |= (xb == np.asarray(-np.inf)) & (yb == np.asarray(-np.inf))
            if equal_nan:
                xnan = np.asarray(np.isnan(xb))
                ynan = np.asarray(np.isnan(yb))
                return r | (xnan & ynan)
            return r

        return impl

else:  # Jim Pivarski implementation
    print('Using Jim Pivarski implementation!')

    @register_jitable
    def _isclose_item(x, y, rtol, atol, equal_nan):
        """Compare two scalars with ``np.isclose`` semantics."""
        if np.isnan(x) and np.isnan(y):
            return equal_nan
        elif np.isinf(x) and np.isinf(y):
            # Two infinities are close only when they share a sign.
            return (x > 0) == (y > 0)
        elif np.isinf(x) or np.isinf(y):
            return False
        else:
            return abs(x - y) <= atol + rtol * abs(y)

    @overload(np.isclose)
    def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        """Numba overload of ``np.isclose`` dispatching on input dimensionality.

        - Either input an array with ndim > 0: elementwise, broadcast result.
        - Otherwise, either input a 0-d array: 0-d boolean array.
        - Both scalars: a plain boolean.
        """
        if (isinstance(a, types.Array) and a.ndim > 0) or (
            isinstance(b, types.Array) and b.ndim > 0
        ):
            def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
                x0 = np.asarray(a)
                y0 = np.asarray(b)
                # ``x0 == y0`` has the broadcast shape of the two operands and
                # raises on incompatible shapes.  Multiplying each operand by
                # ones of that shape broadcasts it into a fresh contiguous
                # array, so the flat loop below never indexes out of bounds
                # and ``reshape`` is always legal.
                ones = np.ones_like(x0 == y0)
                x = (x0 * ones).reshape(-1)
                y = (y0 * ones).reshape(-1)
                out = np.empty(x.size, np.bool_)
                for i in range(x.size):
                    out[i] = _isclose_item(x[i], y[i], rtol, atol, equal_nan)
                return out.reshape(ones.shape)
        elif isinstance(a, types.Array) or isinstance(b, types.Array):
            def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
                return np.asarray(
                    _isclose_item(a.item(), b.item(), rtol, atol, equal_nan)
                )
        else:
            def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
                return _isclose_item(a, b, rtol, atol, equal_nan)
        return isclose_impl
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Forward all arguments to ``np.isclose``.

    Called from plain Python this is NumPy's implementation; wrapped in
    ``njit`` the call resolves to whichever ``np.isclose`` overload Numba
    has registered.
    """
    return np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def values():
    """Yield ``(a, b, kwargs)`` argument triples for exercising ``isclose``.

    The first 13 triples are ad-hoc cases, each with its own keyword
    arguments.  The rest are adapted from NumPy's all-close / none-close /
    some-close fixtures and all share explicit ``rtol``/``atol`` keywords.
    """
    rtol = 1e-5
    atol = 1e-8
    arr = np.array([100, 1000])
    aran = np.arange(8).reshape((2, 2, 2))
    kw = {'rtol': rtol, 'atol': atol}

    # Ad-hoc cases.
    yield 1e10, 1.00001e10, {}
    yield 1e10, np.nan, {}
    yield [1e-8, 1e-7], [0.0, 0.0], {}
    yield [1e10, 1e-7], [1.00001e10, 1e-8], {}
    yield [1e10, 1e-8], [1.00001e10, 1e-9], {}
    yield [1e10, 1e-8], [1.0001e10, 1e-9], {}
    yield [1.0, np.nan], [1.0, np.nan], {}
    yield [1.0, np.nan], [1.0, np.nan], {'equal_nan': True}
    yield [np.nan, np.nan], [1.0, np.nan], {'equal_nan': True}
    yield [1e-100, 1e-7], [0.0, 0.0], {'atol': 0.0}
    yield [1e-10, 1e-10], [1e-20, 0.0], {}
    yield [1e-10, 1e-10], [1e-20, 0.999999e-10], {'atol': 0.0}
    yield [1, np.inf, 2], [3, np.inf, 4], kw

    # Cases adapted from
    # https://github.com/numpy/numpy/blob/aac965af6032b69d5cb515ad785cc9a331e816f4/numpy/core/tests/test_numeric.py#L2298-L2335  # noqa: E501
    # Every element pair is close.
    all_close = [
        ([1, 0], [1, 0]),
        ([atol], [0]),
        (arr, arr),
        ([1], [1 + rtol + atol]),
        (arr, arr + arr * rtol),
        (arr, arr + arr * rtol + atol),
        (aran, aran + aran * rtol),
        (np.inf, np.inf),
        (np.inf, [np.inf]),
        ([np.inf, -np.inf], [np.inf, -np.inf]),
    ]
    # No element pair is close.
    none_close = [
        ([0, 1], [1, 0]),
        (-np.inf, np.inf),
        ([np.inf, 0], [1, np.inf]),
        ([np.inf, -np.inf], [1, 0]),
        ([np.inf, np.inf], [1, -np.inf]),
        ([np.inf, np.inf], [1, 0]),
        ([np.nan, 0], [np.nan, -np.inf]),
        ([atol * 2], [0]),
        ([1], [1 + rtol + atol * 2]),
        (aran, aran + rtol * 1.1 * aran + atol * 1.1),
        (np.array([np.inf, 1]), np.array([0, np.inf])),
    ]
    # Mixed results (NumPy's "some close" group, which also includes an
    # all-NaN case that is not close anywhere without equal_nan).
    some_close = [
        ([np.inf, 0], [np.inf, atol * 2]),
        ([atol, 1, 1e6 * (1 + 2 * rtol) + atol], [0, np.nan, 1e6]),
        (np.arange(3), [0, 1, 2.1]),
        (np.nan, [np.nan, np.nan, np.nan]),
        ([0], [atol, np.inf, -np.inf, np.nan]),
        (0, [atol, np.inf, -np.inf, np.nan]),
    ]
    for a, b in all_close + none_close + some_close:
        yield a, b, kw
# Plain-NumPy reference wrapper and its Numba-compiled counterpart.
pyfunc = isclose
cfunc = njit(isclose)

# Warm up the JIT: call the compiled function once per test case so each
# argument-type signature is compiled before the parametrized tests run.
for lhs, rhs, options in values():
    cfunc(np.asarray(lhs), np.asarray(rhs), **options)
@pytest.mark.parametrize('a, b, kwargs', values())
def test_foo(a, b, kwargs):
    """Check the compiled ``isclose`` against plain NumPy for one case.

    ``cfunc`` is the ``njit``-compiled wrapper and ``pyfunc`` the same wrapper
    run by NumPy; previously the compiled result was computed but never
    checked, so wrong or garbage output could not fail the test.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    got = cfunc(a, b, **kwargs)
    expected = pyfunc(a, b, **kwargs)
    # assert_array_equal treats 0-d operands as scalars and broadcasts them,
    # so compare shapes explicitly as well.
    assert np.shape(got) == np.shape(expected)
    np.testing.assert_array_equal(got, expected)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.