Skip to content

Instantly share code, notes, and snippets.

@nihalkenkre
Created February 16, 2023 13:05
Show Gist options
  • Save nihalkenkre/e51619241de182d8de07acc277787182 to your computer and use it in GitHub Desktop.
Forward Mode Automatic Differentiation Using Python's Abstract Syntax Tree
import inspect
import numpy as np
import ast
import copy
class Variable:
    """A named numeric value whose name records the arithmetic applied to it.

    Every binary operator returns a new ``Variable`` whose ``value`` is the
    numeric result and whose ``name`` joins the operand labels with a tag:
    ``_A_`` (add), ``_S_`` (subtract), ``_M_`` (multiply), ``_D_`` (divide)
    and ``_P_`` (power). A ``Variable`` operand is labelled by its ``name``,
    any other operand by ``str(operand)``. Negation, ``log``, ``exp``,
    ``sin`` and ``cos`` keep the original name.
    """

    def __init__(self, name: str = '', value: float = 0.0):
        self.name = name
        # Any numeric-like object works; main() stores 1-element ndarrays.
        self.value = value

    def __repr__(self):
        return 'Name: %s, Type: %s, Value: %s, Value Type: %s' % (
            self.name, self.__class__.__name__, self.value, type(self.value))

    def _binary(self, other, tag, compute, reflected=False):
        """Apply ``compute`` to the raw operand values and wrap the result.

        Args:
            other: The other operand, a ``Variable`` or a plain value.
            tag: Infix placed between the operand labels in the new name.
            compute: ``compute(left, right)`` receives raw values in
                expression order.
            reflected: True for ``other OP self`` (the ``__r*__`` hooks). This
                puts ``other`` first in both the computation and the name.

        Returns:
            A new ``Variable`` holding the result.
        """
        if isinstance(other, self.__class__):
            other_value, other_label = other.value, other.name
        else:
            other_value, other_label = other, None
        if reflected:
            value = compute(other_value, self.value)
        else:
            value = compute(self.value, other_value)
        if other_label is None:
            # Label plain values only after the arithmetic succeeded.
            other_label = str(other)
        if reflected:
            return Variable(other_label + tag + self.name, value)
        return Variable(self.name + tag + other_label, value)

    def __add__(self, other):
        return self._binary(other, '_A_', lambda a, b: a + b)

    def __radd__(self, other):
        return self._binary(other, '_A_', lambda a, b: a + b, reflected=True)

    def __sub__(self, other):
        return self._binary(other, '_S_', lambda a, b: a - b)

    def __rsub__(self, other):
        return self._binary(other, '_S_', lambda a, b: a - b, reflected=True)

    def __mul__(self, other):
        return self._binary(other, '_M_', lambda a, b: a * b)

    def __rmul__(self, other):
        # Reflected: computes other * self.value. The plain-value path used to
        # compute self.value * other, which is wrong for non-commutative types.
        return self._binary(other, '_M_', lambda a, b: a * b, reflected=True)

    def __truediv__(self, other):
        return self._binary(other, '_D_', lambda a, b: a / b)

    def __rtruediv__(self, other):
        return self._binary(other, '_D_', lambda a, b: a / b, reflected=True)

    # Python 2 spellings of the division hooks, kept for backward compatibility.
    __div__ = __truediv__
    __rdiv__ = __rtruediv__

    def __pow__(self, other):
        return self._binary(other, '_P_', lambda a, b: a ** b)

    def __rpow__(self, other):
        return self._binary(other, '_P_', lambda a, b: a ** b, reflected=True)

    def __neg__(self):
        return Variable(self.name, value=-self.value)

    def log(self):
        return Variable(self.name, value=np.log(self.value))

    def exp(self):
        return Variable(self.name, value=np.exp(self.value))

    def sin(self):
        return Variable(self.name, value=np.sin(self.value))

    def cos(self):
        return Variable(self.name, value=np.cos(self.value))
class ForwardModeAutoDiffAST:
    """Forward-mode automatic differentiation by rewriting a function's AST.

    The constructor reads the source of ``func`` and:

    1. flattens it (``__flatten_func``): every ``name = expr`` assignment is
       split into assignments that each perform one arithmetic operation;
    2. augments it (``__create_df_func``): right after each assignment whose
       value depends on a tracked variable, it computes the derivative into
       ``d<target>``. The generated function takes an extra parameter ``dv``
       (default 0) per tracked name ``v`` and returns ``(result, d<result>)``.

    The generated function is executed into this module's globals under the
    name ``<func>_flat_and_diff``. ``get_gradients`` calls it once per tracked
    name, with that name's seed ``dv`` set to 1.

    Scope: the flattener keeps only ``name = expr`` assignments (hoisting any
    found inside ``if``/``for`` bodies to the top level) and the final
    ``return``; other statements such as ``x += 1`` are dropped. Only
    ``+ - * / **`` and unary ``+``/``-`` over names and constants are
    differentiated; calls such as ``np.sin(x)`` are not.
    """

    def __init__(self, func, df_v_names=None):
        """Flatten ``func``, generate its derivative version and compile it.

        Args:
            func: Function whose source ``inspect.getsource`` can retrieve.
            df_v_names: Iterable of variable names to differentiate with
                respect to. It is copied, never modified.

        Raises:
            ValueError: if ``func`` does not end in a supported ``return``.
        """
        self.__var_names = []
        self.__df_v_names = list(df_v_names) if df_v_names is not None else []
        self.__flattened_func = self.__flatten_func(func)
        self.__func_and_diff = self.__create_df_func(self.__df_v_names)
        namespace = globals()
        # Round-trip through source text so every generated node gets the
        # location fields compile() requires.
        exec(compile(ast.parse(ast.unparse(self.__func_and_diff)),
                     filename='<string>', mode='exec'), namespace)
        # Keep our own reference: another instance built from a function with
        # the same name would overwrite the module-level binding.
        self.__diff_func = namespace[self.__func_and_diff.name]

    def __generate_random_name(self, chars_count: int = 3, var_names=None) -> str:
        """Return a random lowercase name that is safe to assign to.

        The name is not already in ``var_names``, not a keyword and not a
        builtin. It is appended to ``var_names`` so later calls avoid it.
        Letters are drawn from NumPy's global random generator.
        """
        import builtins
        import keyword
        letters = 'abcdefghijklmnopqrstuvwxyz'
        if var_names is None:
            var_names = []
        while True:
            var_name = ''.join(letters[np.random.randint(0, 26)]
                               for _ in range(chars_count))
            usable = (var_name not in var_names
                      and not keyword.iskeyword(var_name)
                      and not hasattr(builtins, var_name))
            if usable:
                var_names.append(var_name)
                return var_name

    @staticmethod
    def __name_tag_stmt(var_name):
        """Build ``try: <v>.name = '<v>' except AttributeError: pass``.

        Plain numbers and arrays reject a ``name`` attribute, so only objects
        that accept it (such as ``Variable``) get renamed.
        """
        return ast.parse(
            'try:\n'
            '    %s.name = %r\n'
            'except AttributeError:\n'
            '    pass\n' % (var_name, var_name)).body[0]

    def __flatten_func(self, func, add_print_stmt=False):
        """Return a FunctionDef equivalent to ``func`` in one-operation form.

        Every ``name = expr`` assignment found in ``func`` (``ast.walk``
        order) is emitted at top level, preceded by the intermediate
        assignments its arithmetic needs. Each assignment to a plain name is
        followed by a best-effort name tag. The final ``return`` is kept, and
        its value is flattened too if it is arithmetic.

        Raises:
            ValueError: if ``func`` has no ``return <value>``, or it returns
                something other than a name or an arithmetic expression.
        """
        import textwrap
        # dedent lets nested (indented) functions be parsed as well.
        function_to_flatten_ast = ast.parse(
            textwrap.dedent(inspect.getsource(func))).body[0]
        # Reserve identifiers already used by func so generated names can't
        # clobber its variables.
        for node in ast.walk(function_to_flatten_ast):
            if isinstance(node, ast.Name):
                self.__var_names.append(node.id)
            elif isinstance(node, ast.arg):
                self.__var_names.append(node.arg)

        df_assigns = []
        return_node = None
        for node in ast.walk(function_to_flatten_ast):
            if isinstance(node, ast.Assign):
                pre_assigns, node.value = self.__flatten_expr(node.value)
                df_assigns.extend(pre_assigns)
                df_assigns.append(node)
            elif isinstance(node, ast.Return):
                return_node = node
        if return_node is None or return_node.value is None:
            raise ValueError('%s must end with "return <value>"' % func.__name__)
        pre_assigns, return_node.value = self.__flatten_expr(return_node.value)
        df_assigns.extend(pre_assigns)
        if not isinstance(return_node.value, ast.Name):
            raise ValueError('%s must return a name or an arithmetic expression'
                             % func.__name__)

        df_body = []
        for assign in df_assigns:
            df_body.append(assign)
            target = assign.targets[0]
            if not isinstance(target, ast.Name):
                continue  # tuple/attribute targets get no tag
            df_body.append(self.__name_tag_stmt(target.id))
            if add_print_stmt:
                df_body.append(ast.Expr(ast.Call(
                    func=ast.Name(id='print'),
                    args=[ast.Constant(target.id), target], keywords=[])))
        df_body.append(return_node)
        return ast.FunctionDef(
            func.__name__ + '_flat', args=function_to_flatten_ast.args,
            body=df_body, decorator_list=[], lineno=1)

    @staticmethod
    def __binop_derivative(binop: ast.BinOp, df_v_names):
        """Return the derivative expression of ``binop``, or None.

        ``g``/``h`` are the left/right operands. An operand contributes a
        derivative (``dg``/``dh``) only if it is a name in ``df_v_names``.
        Operands must be names or constants. None is returned when nothing
        depends on a tracked name, or for unsupported operators.
        """
        g, h, op = binop.left, binop.right, binop.op
        if not (isinstance(g, (ast.Name, ast.Constant))
                and isinstance(h, (ast.Name, ast.Constant))):
            return None
        g_dep = isinstance(g, ast.Name) and g.id in df_v_names
        h_dep = isinstance(h, ast.Name) and h.id in df_v_names
        if not (g_dep or h_dep):
            return None
        dg = ast.Name(id='d' + g.id) if g_dep else None
        dh = ast.Name(id='d' + h.id) if h_dep else None

        def bin_op(left, operator, right):
            return ast.BinOp(left=left, op=operator, right=right)

        def scaled(k, d):
            # Constant factors go first, name factors second (the operand
            # order the generated code has always used).
            if isinstance(k, ast.Constant):
                return bin_op(k, ast.Mult(), d)
            return bin_op(d, ast.Mult(), k)

        def neg(node):
            return ast.UnaryOp(op=ast.USub(), operand=node)

        def np_log(node):
            return ast.Call(func=ast.Attribute(value=ast.Name(id='np'), attr='log'),
                            args=[node], keywords=[])

        if isinstance(op, ast.Add):
            if g_dep and h_dep:
                return bin_op(dg, ast.Add(), dh)
            return dg if g_dep else dh
        if isinstance(op, ast.Sub):
            if g_dep and h_dep:
                return bin_op(dg, ast.Sub(), dh)
            return dg if g_dep else neg(dh)
        if isinstance(op, ast.Mult):
            if g_dep and h_dep:
                return bin_op(bin_op(g, ast.Mult(), dh), ast.Add(),
                              bin_op(dg, ast.Mult(), h))
            return scaled(h, dg) if g_dep else scaled(g, dh)
        if isinstance(op, ast.Div):
            h_squared = bin_op(h, ast.Pow(), ast.Constant(value=2))
            if g_dep and h_dep:
                numerator = bin_op(bin_op(dg, ast.Mult(), h), ast.Sub(),
                                   bin_op(dh, ast.Mult(), g))
                return bin_op(numerator, ast.Div(), h_squared)
            if g_dep:
                return bin_op(bin_op(ast.Constant(value=1), ast.Div(), h),
                              ast.Mult(), dg)
            # d(g/h) = -g * dh / h**2 (the chain-rule dh factor is required).
            return bin_op(bin_op(neg(g), ast.Mult(), dh), ast.Div(), h_squared)
        if isinstance(op, ast.Pow):
            if g_dep and h_dep:
                # d(g**h) = g**h * (dh * ln g + h * dg / g)
                inner = bin_op(bin_op(dh, ast.Mult(), np_log(g)), ast.Add(),
                               bin_op(bin_op(h, ast.Mult(), dg), ast.Div(), g))
                return bin_op(binop, ast.Mult(), inner)
            if g_dep:
                # d(g**h) = h * g**(h - 1) * dg
                if isinstance(h, ast.Constant):
                    exponent = ast.Constant(value=h.value - 1)
                else:
                    exponent = bin_op(h, ast.Sub(), ast.Constant(value=1))
                return bin_op(bin_op(h, ast.Mult(), bin_op(g, ast.Pow(), exponent)),
                              ast.Mult(), dg)
            # d(g**h) = ln(g) * g**h * dh
            return bin_op(bin_op(np_log(g), ast.Mult(), binop), ast.Mult(), dh)
        return None

    def __create_df_assign(self, assign_for_df: ast.Assign,
                           df_v_names: 'list[str] | None' = None):
        """Return ``d<target> = <derivative>`` for ``assign_for_df``, or None.

        ``df_v_names`` lists the names currently known to depend on a tracked
        variable. None is returned when the value does not depend on them,
        when the target is not a plain name, or when the expression kind is
        unsupported (calls, subscripts, ...).
        """
        if df_v_names is None:
            df_v_names = []
        target = assign_for_df.targets[0]
        if not isinstance(target, ast.Name):
            return None
        value = assign_for_df.value
        if isinstance(value, ast.BinOp):
            derivative = self.__binop_derivative(value, df_v_names)
        elif isinstance(value, ast.Name) and value.id in df_v_names:
            derivative = ast.Name(id='d' + value.id)
        else:
            derivative = None
        if derivative is None:
            return None
        return ast.Assign(targets=[ast.Name(id='d' + target.id)],
                          value=derivative, lineno=1)

    def __create_df_func(self, df_v_names=None, add_print_stmt=False):
        """Return the flattened function extended with derivative code.

        The result is named ``<func>_flat_and_diff``. For each tracked name
        ``v`` it gains a trailing parameter ``dv`` (default 0). Right after
        every assignment that depends on a tracked name it computes
        ``d<target>``. It returns ``(result, d<result>)``, or just
        ``(result,)`` when the result depends on no tracked name.
        ``df_v_names`` itself is not modified.
        """
        tracked = list(df_v_names) if df_v_names is not None else []
        function_and_df = copy.deepcopy(self.__flattened_func)
        function_and_df.name += '_and_diff'
        for df_v_name in list(tracked):
            function_and_df.args.args.append(ast.arg('d' + df_v_name))
            function_and_df.args.defaults.append(ast.Constant(value=0))
        return_node = function_and_df.body.pop()

        body = []
        for stmt in function_and_df.body:
            body.append(stmt)
            if not isinstance(stmt, ast.Assign):
                continue
            df_assign = self.__create_df_assign(stmt, tracked)
            if df_assign is None:
                continue
            # Emitting the derivative right after its primal assignment makes
            # it use that moment's operand values, even if a name is
            # reassigned later.
            body.append(df_assign)
            tracked.append(stmt.targets[0].id)
            if add_print_stmt:
                body.append(ast.Expr(ast.Call(
                    func=ast.Name(id='print'),
                    args=[ast.Constant(df_assign.targets[0].id),
                          df_assign.targets[0]],
                    keywords=[])))

        returned = return_node.value  # an ast.Name, guaranteed by flattening
        return_df_elts = [returned]
        if returned.id in tracked:
            return_df_elts.append(ast.Name(id='d' + returned.id))
        body.append(ast.Return(value=ast.Tuple(elts=return_df_elts, ctx=ast.Load())))
        function_and_df.body = body
        return function_and_df

    def __flatten_expr(self, expr):
        """Split arithmetic ``expr`` into single-operation assignments.

        Returns ``(assigns, replacement)``. ``assigns`` must run first;
        ``replacement`` is the node that now stands for ``expr`` (the last
        target, or ``expr`` itself when there is nothing to split).
        """
        if isinstance(expr, ast.BinOp):
            assigns = self.__binop_to_assign(expr)
        elif isinstance(expr, ast.UnaryOp) and isinstance(expr.op, (ast.UAdd, ast.USub)):
            assigns = self.__unop_to_assign(expr)
        else:
            return [], expr
        return assigns, assigns[-1].targets[0]

    def __unop_to_assign(self, node: ast.UnaryOp) -> list[ast.Assign]:
        """Rewrite ``+a``/``-a`` as a ``1 * a``/``-1 * a`` assignment.

        An arithmetic operand is flattened first. Returns the assignments;
        the last one holds the result under a fresh random name.

        Raises:
            NotImplementedError: for unary operators other than ``+``/``-``.
        """
        if not isinstance(node.op, (ast.UAdd, ast.USub)):
            raise NotImplementedError(type(node.op).__name__)
        factor = 1 if isinstance(node.op, ast.UAdd) else -1
        assigns, operand = self.__flatten_expr(node.operand)
        target = ast.Name(id=self.__generate_random_name(var_names=self.__var_names))
        value = ast.BinOp(left=ast.Constant(value=factor), op=ast.Mult(), right=operand)
        assigns.append(ast.Assign(targets=[target], value=value, lineno=1))
        return assigns

    def __operand_label(self, operand):
        """Name fragment for ``operand`` in a generated target name."""
        if isinstance(operand, ast.Name):
            return operand.id
        return self.__generate_random_name(var_names=self.__var_names)

    def __binop_to_assign(self, node: ast.BinOp) -> list[ast.Assign]:
        """Split ``node`` into assignments of one binary operation each.

        Nested arithmetic operands are flattened first (left, then right) and
        replaced by their result names. The result is assigned to
        ``<left>_<op>_<right>``. A name operand contributes its id; any other
        operand (e.g. a constant) contributes a fresh random name.

        Returns the assignments; the last one holds ``node``'s value.
        """
        left_assigns, node.left = self.__flatten_expr(node.left)
        right_assigns, node.right = self.__flatten_expr(node.right)
        assigns = left_assigns + right_assigns
        left_label = self.__operand_label(node.left)
        right_label = self.__operand_label(node.right)
        target_id = '%s_%s_%s' % (left_label, node.op.__class__.__name__.lower(),
                                  right_label)
        assigns.append(ast.Assign(targets=[ast.Name(id=target_id)], value=node, lineno=1))
        return assigns

    def __get_gradient_for_df_v_name(self, df_v_name, kwargs):
        """Run the generated function seeded for ``df_v_name``.

        Returns:
            ``[value, derivative of value with respect to df_v_name]``.
        """
        call_kwargs = dict(kwargs)
        call_kwargs['d' + df_v_name] = 1
        value, derivative = self.__diff_func(**call_kwargs)
        return [value, derivative]

    def get_gradients(self, **kwargs):
        """Return ``{name: [value, d value / d name]}`` for each tracked name.

        ``kwargs`` are passed by keyword to the generated function, so they
        must name the original function's parameters.
        """
        return {df_v_name: self.__get_gradient_for_df_v_name(df_v_name, kwargs)
                for df_v_name in self.__df_v_names}
import ast
import numpy as np
import matplotlib.pyplot as plt
def train_loop(X, y, w, b):
    """Return the squared error of the linear model ``w * X + b`` against ``y``.

    NOTE: ForwardModeAutoDiffAST parses this function's source, and its
    flattener only handles ``name = expression`` assignments followed by
    ``return <name>``. Keep the body in that form. The local names also end
    up as ``.name`` of the intermediate and returned Variables.
    """
    y_pred = w * X + b
    loss = (y_pred - y) ** 2
    return loss
def predict(X, w, b):
    """Evaluate the linear model ``w * X + b`` (works on Variables or numbers)."""
    scaled = w * X
    return scaled + b
def main(epochs=100, learning_rate=0.0001):
    """Fit ``y = w * X + b`` to one random sample by gradient descent.

    Gradients of ``train_loop`` with respect to ``w`` and ``b`` come from
    ForwardModeAutoDiffAST. The per-epoch loss is printed every tenth of the
    run and plotted at the end.

    Args:
        epochs: Number of gradient-descent steps.
        learning_rate: Step size applied to each gradient.
    """
    w = Variable(name='w', value=np.random.rand(1))
    b = Variable(name='b', value=np.random.rand(1))
    X_train = Variable(name='X', value=np.random.rand(1) * -100)
    y_train = Variable(name='y', value=np.random.rand(1) * 100)
    print(
        f'X={X_train.value} y_pred={predict(X_train, w, b).value}, y={y_train.value}')

    train_losses = []
    fm_ad = ForwardModeAutoDiffAST(train_loop, [w.name, b.name])
    # Integer step: the old float modulo printed every epoch when epochs < 10.
    log_every = max(1, epochs // 10)
    for epoch in range(epochs):
        gradients = fm_ad.get_gradients(X=X_train, y=y_train, w=w, b=b)
        loss = gradients[w.name][0].value
        # Arithmetic renames Variables (e.g. 'w_S_...'); restore the names
        # because they are the keys of the gradients dict.
        w = w - gradients[w.name][1].value * learning_rate
        w.name = 'w'
        b = b - gradients[b.name][1].value * learning_rate
        b.name = 'b'
        train_losses.append(loss)
        if epoch % log_every == 0:
            print(f'Train Loss at {epoch} = {loss}')

    plt.plot(range(epochs), train_losses, label='Train Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    print(
        f'X={X_train.value} y_pred={predict(X_train, w, b).value}, y={y_train.value}')
    # Without show() the figure was built and then silently discarded.
    plt.show()
# Run the demo training only when executed as a script, not on import.
if __name__ == '__main__':
    main()
@nihalkenkre
Copy link
Author

This is an unoptimized, uncommented working first version.

Please see the accompanying blog post [here](https://dev.to/nihalkenkre/autodiff-and-python-ast-part-1-n86).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment