Skip to content

Instantly share code, notes, and snippets.

@guilhermeleobas
Created January 29, 2021 03:55
Show Gist options
  • Save guilhermeleobas/2c79e1dee40f779bfc5c522fb80e735f to your computer and use it in GitHub Desktop.
Save guilhermeleobas/2c79e1dee40f779bfc5c522fb80e735f to your computer and use it in GitHub Desktop.
from numba import njit, extending, types
from numba.types import int64, double
from numba.core import funcdesc, sigutils
from llvmlite import ir
class External:
    """Expose an external (C-ABI) function symbol to Numba-compiled code.

    An instance is meant to be referenced from ``@njit`` functions. Calling
    ``register()`` installs:

    * a concrete typing template (keyed by ``symbol``) whose signatures are
      ``self.signatures``, registered as the type of this instance, and
    * one lowering per signature that emits a direct call to the raw
      symbol ``symbol``.

    Parameters
    ----------
    symbol : str
        Name of the external function to call. Also used as the typing and
        lowering key.
    signatures : str or iterable
        A single signature string, or an iterable of signatures accepted by
        ``numba.core.sigutils.normalize_signature`` (e.g. ``'double(double)'``).
        Every signature must spell out a return type.

    Raises
    ------
    TypeError
        If a signature does not declare a return type.
    """

    def __init__(self, symbol, signatures):
        self.symbol = symbol
        # List of numba Signature objects (retty(*argtys)), filled below.
        self.signatures = []
        if isinstance(signatures, str):
            signatures = [signatures]
        self._normalize_signatures(signatures)

    def __repr__(self):
        return f"{type(self).__name__}({self.symbol!r}, {self.signatures!r})"

    def _normalize_signatures(self, signatures):
        """Parse each signature and append it to ``self.signatures``."""
        for sig in signatures:
            argtys, retty = sigutils.normalize_signature(sig)
            if retty is None:
                # An args-only signature yields no return type; calling
                # None(*argtys) below would fail with an unhelpful
                # "'NoneType' object is not callable".
                raise TypeError(
                    f"signature {sig!r} for external function "
                    f"{self.symbol!r} must declare a return type")
            self.signatures.append(retty(*argtys))

    def _type_infer(self):
        """Register a concrete typing template for ``self.signatures``.

        The template is keyed by ``self.symbol`` and registered as the global
        type of this instance, so ``@njit`` code referencing the instance can
        type calls to it.
        """
        from numba.core.typing.templates import (make_concrete_template,
                                                 infer_global, infer)
        template = make_concrete_template(
            self.symbol, key=self.symbol, signatures=self.signatures)
        infer(template)
        infer_global(self, types.Function(template))

    def _lower_external_call(self):
        """Register, for every signature, a lowering that calls the symbol."""
        def codegen(context, builder, sig, args):
            # Declare ``self.symbol`` in the module being built and emit a
            # direct call with the already-lowered argument values.
            fndesc = funcdesc.ExternalFunctionDescriptor(
                self.symbol, sig.return_type, sig.args)
            func = context.declare_external_function(
                builder.module, fndesc)
            return builder.call(func, args)

        for sig in self.signatures:
            extending.lower_builtin(self.symbol, *sig.args)(codegen)

    def register(self):
        """Install typing and lowering for this external function.

        Raises
        ------
        ValueError
            If there are no signatures to register. Without this check the
            failure would only surface later, while typing a call site.
        """
        if not self.signatures:
            raise ValueError(
                f"external function {self.symbol!r} has no signatures "
                f"to register")
        self._type_infer()
        self._lower_external_call()
def external(*args):
    """Build a decorator that turns a Python stub into an ``External`` binding.

    Every positional argument is a signature (e.g. ``'double(double)'``).
    The decorated function's ``__name__`` becomes the external symbol name.
    The decorator returns the registered ``External`` object in place of the
    stub, so the stub's body is never used.
    """
    def _bind(stub):
        # Fall back to str(stub) when the object carries no __name__.
        symbol = getattr(stub, '__name__', str(stub))
        binding = External(symbol, args)
        binding.register()
        return binding

    return _bind
@external('double(double)')
def log10(a):
    """Declare the external symbol ``log10`` with signature double(double).

    Only this stub's name and the decorator's signature matter: ``external``
    replaces the function with the registered ``External`` object.
    """
@external('double(double)')
def log2(a):
    """Declare the external symbol ``log2`` with signature double(double).

    Only this stub's name and the decorator's signature matter: ``external``
    replaces the function with the registered ``External`` object.
    """
@external('int64(int64)')
def llabs(a):
    """Declare the external symbol ``llabs`` with signature int64(int64).

    The symbol name comes from this stub's ``__name__``. The C function
    ``abs`` is ``int abs(int)``, and ``int`` is 32 bits on common ABIs, so
    binding it as int64(int64) mismatches its calling convention: values
    outside the 32-bit range are silently mangled. ``llabs`` takes and
    returns ``long long`` (at least 64 bits), which matches int64.
    """


# Keep the original module-level name so existing callers (e.g. ``foo``)
# still work. This intentionally shadows the builtin ``abs`` in this module.
abs = llabs
@njit('double(double)')
def foo(a):
    """Demo kernel: sum log10(a), log2(a) and abs(-321) via external calls.

    Compiled eagerly for double(double). ``log10``, ``log2`` and ``abs`` are
    the module-level ``External`` bindings, not Python's own functions.
    """
    decimal_log = log10(a)
    binary_log = log2(a)
    magnitude = abs(-321)
    return decimal_log + binary_log + magnitude
if __name__ == "__main__":
    # Run the demo only when executed as a script; importing this module
    # should register the externals without printing anything.
    print(foo(1000.0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment