Skip to content

Instantly share code, notes, and snippets.

@guilhermeleobas
Created February 9, 2021 21:07
Show Gist options
  • Save guilhermeleobas/438cb16360fdab6bdcc72fb7f293bad2 to your computer and use it in GitHub Desktop.
Save guilhermeleobas/438cb16360fdab6bdcc72fb7f293bad2 to your computer and use it in GitHub Desktop.
from numba.core.typing import typeof, templates
from numba.core import extending, types
class ConstFunction:
    """Wrap a zero-argument Python callable so Numba treats calls as constants.

    When a call to the wrapped object is typed inside jitted code, the Python
    callable is evaluated at compile time. Its result determines the call's
    return type (via ``typeof`` with ``Purpose.constant``) and is emitted into
    the generated code as a constant. Instances register themselves with
    Numba's typing and lowering registries on construction.
    """

    @classmethod
    def fromobject(cls, func):
        """Build a ConstFunction from *func*, named after ``func.__name__``.

        Falls back to ``str(func)`` for callables without ``__name__``.
        Usable as a decorator.
        """
        name = getattr(func, '__name__', str(func))
        return cls(func, name)

    def __init__(self, func, name):
        self.func = func  # zero-argument callable evaluated during typing
        self.name = name  # human-readable label (repr / diagnostics only)
        self.register()

    def __repr__(self):
        return f'{type(self).__name__}({self.name!r})'

    def __call__(self):
        """Evaluate the wrapped callable directly in the Python interpreter."""
        return self.func()

    def register(self):
        """Register typing and lowering for calls to this object in jitted code."""
        const_func = self  # avoid confusion with the template's own ``self``

        class ConstFunctionTemplate(templates.AbstractTemplate):
            # Key on this ConstFunction instance rather than on its name:
            # names are not unique (every lambda is '<lambda>'), and a shared
            # key would let one function's constant be lowered for another's
            # call, since all of them have the same () signature.
            key = const_func

            def generic(self, args, kws):
                # Constant functions take no arguments. Returning None makes
                # Numba reject the call at typing time, instead of failing
                # later in lowering with a mismatched signature.
                if args or kws:
                    return None
                # Evaluated on every typing pass (deliberately not cached) so
                # the value may depend on the compilation context.
                retval = const_func.func()
                retty = typeof.typeof(retval, purpose=typeof.Purpose.constant)

                def codegen(context, builder, sig, args):
                    # get_constant_generic dispatches to the constant lowering
                    # registered for retty, so values such as np.bool_ are
                    # rendered correctly rather than relying on a raw LLVM
                    # type call.
                    return context.get_constant_generic(builder, retty, retval)

                # Registered per typing pass because codegen closes over the
                # value computed just above.
                extending.lower_builtin(const_func)(codegen)
                return retty()

        templates.infer(ConstFunctionTemplate)
        templates.infer_global(self, types.Function(ConstFunctionTemplate))


# Decorator form: ``@constfunc`` turns a zero-argument function into a
# compile-time constant inside jitted code.
constfunc = ConstFunction.fromobject
# usage
@constfunc
def IS_GPU():
    """Return True when ``TargetInfo().is_gpu`` is truthy, otherwise False.

    Wrapped by ``constfunc``, so inside jitted code this function is evaluated
    while the call is being typed, and its boolean result is embedded as a
    constant.
    """
    # NOTE(review): TargetInfo is neither defined nor imported in this
    # snippet -- confirm it is in scope (presumably the target-info class of
    # the surrounding project).
    return bool(TargetInfo().is_gpu)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment