Last active
June 5, 2020 13:08
-
-
Save hugohadfield/e238a139d13cd0dd4e17bb39b7b577bd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast
import functools
import inspect
import textwrap
import time

import astpretty
import numba
import numpy as np
from numba import types
from numba.extending import overload
# Star import stays last so it keeps precedence over the names imported above.
from clifford.g3c import *
# Bind the layout's product kernels to bare module-level names so the
# overload implementations below can call them directly.  They back the
# `*`, `^` and `|` operators respectively (presumably the geometric, outer
# and inner products -- confirm against the clifford documentation).
gmt_func, omt_func, imt_func = (
    layout.gmt_func,
    layout.omt_func,
    layout.imt_func,
)

# The `.value` attributes of the e1 and e2 blades.
e1_val, e2_val = e1.value, e2.value
def get_as_ga_func(layout):
    """Build a jitted scalar-promotion function for ``layout``.

    The returned function takes a scalar ``x`` and returns a new array of
    length ``layout.gaDims`` that is zero everywhere except index 0, which
    holds ``x`` (presumably the scalar coefficient slot -- confirm against
    the layout's blade ordering).
    """
    n_coeffs = layout.gaDims

    @numba.njit
    def as_ga(x):
        promoted = np.zeros(n_coeffs)
        promoted[0] = x
        return promoted

    return as_ga


# Scalar promotion specialised to the module-level layout.
as_ga = get_as_ga_func(layout)
def ga_add(x):
    """Placeholder for numba's ``@overload`` to attach addition kernels to.

    Called directly from Python it does nothing and returns ``None``.
    """
@overload(ga_add, inline='always')
def ol_ga_add(a, b):
    """Numba overload of ``ga_add``: choose an addition kernel from the argument types.

    * scalar + array / array + scalar: copy the array and add the scalar to
      element 0 only.
    * anything else: fall back to the plain ``+`` operator.

    ``a`` and ``b`` are numba *types* here; the returned ``impl`` is what
    runs on the actual values.
    """
    scalar_types = (types.Integer, types.Float)

    if isinstance(a, scalar_types) and isinstance(b, types.Array):
        def impl(a, b):
            # astype makes a fresh array, so the caller's operand is not
            # mutated.  float64 matches the arrays built by `as_ga`
            # (np.zeros default dtype); float32 would drop precision.
            op = b.astype(np.float64)
            op[0] += a
            return op
        return impl
    elif isinstance(a, types.Array) and isinstance(b, scalar_types):
        def impl(a, b):
            # Same as above with the operand roles swapped.
            op = a.astype(np.float64)
            op[0] += b
            return op
        return impl
    else:
        def impl(a, b):
            return a + b
        return impl
def ga_sub(x):
    """Placeholder for numba's ``@overload`` to attach subtraction kernels to.

    Called directly from Python it does nothing and returns ``None``.
    """
@overload(ga_sub, inline='always')
def ol_ga_sub(a, b):
    """Numba overload of ``ga_sub``: choose a subtraction kernel from the argument types.

    * scalar - array: negate a copy of the array, then add the scalar to
      element 0.
    * array - scalar: copy the array and subtract the scalar from element 0.
    * anything else: fall back to the plain ``-`` operator.

    ``a`` and ``b`` are numba *types* here; the returned ``impl`` is what
    runs on the actual values.
    """
    scalar_types = (types.Integer, types.Float)

    if isinstance(a, scalar_types) and isinstance(b, types.Array):
        def impl(a, b):
            # float64 matches the arrays built by `as_ga` (np.zeros default
            # dtype); float32 would drop precision.  The negated copy means
            # the caller's operand is not mutated.
            op = -b.astype(np.float64)
            op[0] += a
            return op
        return impl
    elif isinstance(a, types.Array) and isinstance(b, scalar_types):
        def impl(a, b):
            op = a.astype(np.float64)
            op[0] -= b
            return op
        return impl
    else:
        def impl(a, b):
            return a - b
        return impl
def ga_mul(x):
    """Placeholder for numba's ``@overload`` to attach ``*`` kernels to.

    Called directly from Python it does nothing and returns ``None``.
    """
@overload(ga_mul, inline='always')
def ol_ga_mul(a, b):
    """Numba overload of ``ga_mul``.

    Two array operands are combined with ``gmt_func``; every other
    combination of types falls back to the plain ``*`` operator.
    """
    both_arrays = isinstance(a, types.Array) and isinstance(b, types.Array)

    if not both_arrays:
        def impl(a, b):
            return a*b
        return impl

    def impl(a, b):
        return gmt_func(a, b)
    return impl
def ga_xor(x):
    """Placeholder for numba's ``@overload`` to attach ``^`` kernels to.

    Called directly from Python it does nothing and returns ``None``.
    """
@overload(ga_xor, inline='always')
def ol_ga_xor(a, b):
    """Numba overload of ``ga_xor``.

    Array operands go through ``omt_func``; an integer or float operand
    paired with an array is first promoted with ``as_ga``.  Any other
    combination falls back to the plain ``^`` operator.
    """
    scalar_kinds = (types.Integer, types.Float)
    a_is_array = isinstance(a, types.Array)
    b_is_array = isinstance(b, types.Array)

    if a_is_array and b_is_array:
        def impl(a, b):
            return omt_func(a, b)
    elif a_is_array and isinstance(b, scalar_kinds):
        def impl(a, b):
            return omt_func(a, as_ga(b))
    elif isinstance(a, scalar_kinds) and b_is_array:
        def impl(a, b):
            return omt_func(as_ga(a), b)
    else:
        def impl(a, b):
            return a^b
    return impl
def ga_or(x):
    """Placeholder for numba's ``@overload`` to attach ``|`` kernels to.

    Called directly from Python it does nothing and returns ``None``.
    """
@overload(ga_or, inline='always')
def ol_ga_or(a, b):
    """Numba overload of ``ga_or``.

    Array operands go through ``imt_func``; an integer or float operand
    paired with an array is first promoted with ``as_ga``.  Any other
    combination falls back to the plain ``|`` operator.
    """
    scalar_kinds = (types.Integer, types.Float)
    a_is_array = isinstance(a, types.Array)
    b_is_array = isinstance(b, types.Array)

    if a_is_array and b_is_array:
        def impl(a, b):
            return imt_func(a, b)
    elif a_is_array and isinstance(b, scalar_kinds):
        def impl(a, b):
            return imt_func(a, as_ga(b))
    elif isinstance(a, scalar_kinds) and b_is_array:
        def impl(a, b):
            return imt_func(as_ga(a), b)
    else:
        def impl(a, b):
            return a|b
    return impl
class jit_func(object):
    """Decorator factory that recompiles a function to run on raw coefficient arrays.

    Applying an instance to a function:

    1. fetches the function's source and parses it into an AST,
    2. rewrites the binary operators with ``GATransformer`` so they call the
       ``ga_*`` functions,
    3. compiles the rewritten AST, executes it against this module's globals
       and JIT-compiles the resulting function with ``numba.njit``,
    4. returns a wrapper that unwraps each argument's ``.value``, calls the
       jitted function and wraps the result in ``layout.MultiVector``.

    Args:
        ast_debug: when true, pretty-print the AST before and after the
            rewrite.
    """

    def __init__(self, ast_debug=False):
        self.ast_debug = ast_debug

    def __call__(self, func):
        """Return the JIT-compiled, MultiVector-wrapping replacement for ``func``.

        All decorators on ``func`` are discarded from the recompiled
        definition; re-executing them would re-apply this decorator.
        """
        fname = func.__name__

        # Dedent so functions defined inside a class or another function
        # still parse as a top-level definition.
        source = textwrap.dedent(inspect.getsource(func))
        tree = ast.parse(source)

        # Strip decorators on the AST rather than by dropping the first
        # source line: this also copes with multi-line or stacked decorators
        # and with functions passed in undecorated.
        for stmt in tree.body:
            if isinstance(stmt, ast.FunctionDef) and stmt.name == fname:
                stmt.decorator_list = []

        # Re-write the ast
        if self.ast_debug:
            print('\n\n\n\n TRANFORMING FROM \n\n\n\n')
            astpretty.pprint(tree)
        tree = GATransformer().visit(tree)
        ast.fix_missing_locations(tree)
        if self.ast_debug:
            print('\n\n\n\n TRANFORMING TO \n\n\n\n')
            astpretty.pprint(tree)

        # Compile the function; module globals give it access to the ga_*
        # helpers, while the definition itself lands in a private namespace.
        code = compile(tree, '<ast>', "exec")
        namespace = {}
        exec(code, globals(), namespace)
        new_func = namespace[fname]

        # JIT the function
        compiled = numba.njit(new_func)

        def unwrap(obj):
            # Objects exposing `.value` contribute that attribute; anything
            # else (e.g. a plain scalar or array) is passed through as-is.
            return getattr(obj, 'value', obj)

        # Wrap the jitted function
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            raw_args = [unwrap(arg) for arg in args]
            raw_kwargs = {key: unwrap(val) for key, val in kwargs.items()}
            return layout.MultiVector(value=compiled(*raw_args, **raw_kwargs))
        return wrapper
class GATransformer(ast.NodeTransformer):
    """Rewrite binary operators into calls to the ``ga_*`` functions.

    ``a * b`` becomes ``ga_mul(a, b)``, ``a ^ b`` becomes ``ga_xor(a, b)``,
    ``a | b`` becomes ``ga_or(a, b)``, ``a + b`` becomes ``ga_add(a, b)`` and
    ``a - b`` becomes ``ga_sub(a, b)``.  Other binary operators are kept,
    but their operands are still rewritten.
    """

    # Operator node type -> name of the function that replaces it.
    _OPERATOR_FUNCS = {
        ast.Mult: 'ga_mul',
        ast.BitXor: 'ga_xor',
        ast.BitOr: 'ga_or',
        ast.Add: 'ga_add',
        ast.Sub: 'ga_sub',
    }

    def visit_BinOp(self, node):
        """Return the replacement for ``node`` with its operands already rewritten."""
        # Rewrite the operands first, whatever the operator is; otherwise
        # something like ``(a * b) / 2`` would keep its inner ``*``.
        self.generic_visit(node)

        func_name = self._OPERATOR_FUNCS.get(type(node.op))
        if func_name is None:
            return node

        call = ast.Call(
            func=ast.Name(id=func_name, ctx=ast.Load()),
            args=[node.left, node.right],
            keywords=[],
        )
        # Keep the original position so tracebacks point at the source line.
        return ast.copy_location(call, node)
@jit_func(ast_debug=True)
def test_func(A, B, C):
    """Benchmark expression run through ``jit_func``.

    The decorator rewrites every operator below into a ``ga_*`` call and
    JIT-compiles the result; ``ast_debug=True`` makes it print the AST before
    and after the rewrite when this definition is executed.  The expression
    mixes ``*``, ``^``, ``|``, ``+`` and ``-`` with both array and scalar
    operands.  It is kept textually identical to ``slow_test_func`` so the
    two can be compared.
    """
    op = (((A*B)*C)|(B^A)) - 3.1 - A - 7*B + 5 + C + 2.5 + (2^(A*B*C)^3) + (A|5)
    return op
def slow_test_func(A, B, C):
    """Reference version of the benchmark expression, left undecorated.

    The operators are evaluated by the operands' own Python operator
    overloads.  The expression is kept textually identical to ``test_func``.
    """
    op = (((A*B)*C)|(B^A)) - 3.1 - A - 7*B + 5 + C + 2.5 + (2^(A*B*C)^3) + (A|5)
    return op
def _mean_call_time_us(func, nrepeats):
    """Return the mean wall-clock time of ``func(e1, e2, einf)`` in microseconds."""
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    for _ in range(nrepeats):
        func(e1, e2, einf)
    end_time = time.perf_counter()
    return 1E6*(end_time - start_time)/nrepeats


# Print both results once so they can be compared by eye.  These calls run
# before the timed loops, so any one-off cost of the first call to the
# jitted function (numba compiles on first use) is excluded from the timing.
print(test_func(e1, e2, einf))
print(slow_test_func(e1, e2, einf))

# Mean time per call in microseconds: jitted version, then pure-Python one.
nrepeats = 100000
print(_mean_call_time_us(test_func, nrepeats))
print(_mean_call_time_us(slow_test_func, nrepeats))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment