Created January 29, 2021 03:55
-
-
Save guilhermeleobas/2c79e1dee40f779bfc5c522fb80e735f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import njit, extending, types | |
from numba.types import int64, double | |
from numba.core import funcdesc, sigutils | |
from llvmlite import ir | |
class External:
    """Handle that makes an external (C-ABI) symbol callable from njit code.

    After :meth:`register` is called, the instance can be referenced as a
    global inside ``@njit`` functions.
    - Typing: calls are resolved against the declared signatures through a
      concrete typing template.
    - Lowering: each call declares ``symbol`` in the LLVM module being built,
      via ``context.declare_external_function``, and emits a direct call to it.

    Parameters
    ----------
    symbol : str
        Name of the external symbol to call. It is also used as the key that
        links the typing template to the lowering implementations.
    signatures : str or iterable of signatures
        One signature (e.g. ``'double(double)'``) or an iterable of them.
        Every signature must include a return type.

    Raises
    ------
    ValueError
        If no signatures are given.
    TypeError
        If a signature does not specify a return type.
    """

    def __init__(self, symbol, signatures):
        self.symbol = symbol
        self.signatures = []
        if isinstance(signatures, str):
            signatures = [signatures]
        # Materialise once so generators work and emptiness can be checked.
        signatures = list(signatures)
        if not signatures:
            # A template with no signatures can never resolve a call, so
            # every use would fail later with an obscure typing error.
            raise ValueError(
                f"external function {symbol!r} needs at least one signature")
        self._normalize_signatures(signatures)

    def _normalize_signatures(self, signatures):
        """Append a numba ``Signature`` to ``self.signatures`` for each entry."""
        for sig in signatures:
            argtys, retty = sigutils.normalize_signature(sig)
            # normalize_signature yields retty=None for argument-only forms.
            # Without this check, ``retty(*argtys)`` would fail with
            # "'NoneType' object is not callable" and hide the real mistake.
            if retty is None:
                raise TypeError(
                    f"signature {sig!r} for external function "
                    f"{self.symbol!r} must specify a return type, "
                    "e.g. 'double(double)'")
            self.signatures.append(retty(*argtys))

    def _type_infer(self):
        """Register the typing side: a concrete template and a global binding."""
        from numba.core.typing.templates import (make_concrete_template,
                                                 infer_global, infer)
        # The template resolves calls against the fixed list of signatures.
        template = make_concrete_template(
            self.symbol, key=self.symbol, signatures=self.signatures)
        infer(template)
        # Bind this very object as a global of type Function(template), so
        # references to it inside njit code are typed through the template.
        infer_global(self, types.Function(template))

    def _lower_external_call(self):
        """Register, for each signature, a lowering that emits a direct call."""
        def codegen(context, builder, sig, args):
            # Declare the symbol in the module under construction, using the
            # resolved signature's types, then call it with the lowered args.
            fndesc = funcdesc.ExternalFunctionDescriptor(
                self.symbol, sig.return_type, sig.args)
            func = context.declare_external_function(
                builder.module, fndesc)
            return builder.call(func, args)

        for sig in self.signatures:
            # The lowering key must match the typing template's key.
            extending.lower_builtin(self.symbol, *sig.args)(codegen)

    def register(self):
        """Register both the typing and the lowering for this function."""
        self._type_infer()
        self._lower_external_call()

    def __repr__(self):
        sigs = ', '.join(str(s) for s in self.signatures)
        return f"{type(self).__name__}({self.symbol!r}, [{sigs}])"
def external(*args, symbol=None):
    """Build a decorator that turns a Python stub into a registered ``External``.

    The decorated function is never called. It is replaced by the
    ``External`` instance, and the stub's name (or *symbol*, if given)
    selects the external C symbol. The stub's docstring, if any, is kept on
    the returned object.

    Parameters
    ----------
    *args
        Signatures such as ``'double(double)'``, forwarded to ``External``.
    symbol : str, optional
        External symbol to bind when it differs from the stub's Python name,
        e.g. ``@external('int64(int64)', symbol='llabs')``.

    Returns
    -------
    callable
        A decorator that registers and returns the ``External`` instance.
    """
    def decorate(func):
        if symbol is not None:
            name = symbol
        else:
            name = getattr(func, '__name__', str(func))
        llc = External(name, args)
        llc.register()
        # The stub is discarded, so carry its documentation over to the
        # object that takes its place.
        doc = getattr(func, '__doc__', None)
        if doc:
            llc.__doc__ = doc
        return llc
    return decorate
@external('double(double)')
def log10(a):
    """Stub for the external C symbol ``log10`` (double -> double).

    The decorator replaces this function with a registered ``External``, so
    the body never runs.
    """
@external('double(double)')
def log2(a):
    """Stub for the external C symbol ``log2`` (double -> double).

    The decorator replaces this function with a registered ``External``, so
    the body never runs.
    """
# C's ``abs`` is declared ``int abs(int)``. Binding it with an int64
# signature mismatches the C ABI for values outside the 32-bit range.
# ``llabs`` takes and returns ``long long`` (64 bits on mainstream
# platforms), which matches the declared int64 signature.
@external('int64(int64)')
def llabs(a):
    """Stub for the external C symbol ``llabs`` (int64 -> int64).

    The decorator replaces this function with a registered ``External``, so
    the body never runs.
    """


# Keep the module-level name ``abs`` that existing callers use. Note that it
# shadows the builtin ``abs`` within this module.
abs = llabs
@njit('double(double)')
def foo(a):
    """Return ``log10(a) + log2(a) + abs(-321)``, compiled in nopython mode.

    Each call is routed through a module-level ``External`` handle, so this
    exercises typing and lowering of the external symbols.
    """
    logs = log10(a) + log2(a)
    return logs + abs(-321)
if __name__ == '__main__':
    # Should print about 333.97: log10(1000) = 3, log2(1000) ~= 9.97, and
    # abs(-321) = 321. The guard keeps the print from firing on import.
    print(foo(1000.0))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.