Skip to content

Instantly share code, notes, and snippets.

@gngdb
Last active March 27, 2019 06:52
Show Gist options
  • Save gngdb/81049f9796d1292b672999f1c1f21ab7 to your computer and use it in GitHub Desktop.
Save gngdb/81049f9796d1292b672999f1c1f21ab7 to your computer and use it in GitHub Desktop.
How to use `Defun` to define custom gradients in TensorFlow
import tensorflow as tf
from tensorflow.python.framework import function
@function.Defun()
def my_op_grad(op, grad):
    """Gradient function registered for ``my_op`` through ``grad_func``.

    Despite its name, ``op`` is not a ``tf.Operation``. A ``Defun`` gradient
    function receives the forward function's inputs followed by the gradients
    with respect to its outputs (the ``GradientDef`` contract). For
    ``my_op(a)``, this means ``op`` is the input tensor ``a`` and ``grad`` is
    the upstream gradient dL/d(my_op(a)).

    Args:
        op: The forward input tensor ``a`` passed to ``my_op``.
        grad: Upstream gradient with respect to ``my_op``'s output.

    Returns:
        ``sigmoid(op) * grad``: the custom local derivative ``sigmoid(op)``
        chained with the upstream gradient.
    """
    # Multiply by the upstream gradient so the chain rule holds when my_op's
    # output feeds further ops. Returning sigmoid(op) alone is only correct
    # when the upstream gradient is all ones.
    return tf.sigmoid(op) * grad
@function.Defun(grad_func=my_op_grad)
def my_op(a):
    """Pass ``a`` through unchanged, with a custom registered gradient.

    The forward pass is ``tf.identity``. Backpropagation does not use the
    identity's gradient; it uses ``my_op_grad``, registered via ``grad_func``.
    """
    passthrough = tf.identity(a)
    return passthrough
def main():
    """Print the gradient of ``my_op`` for a small sample vector.

    Builds a TF1 graph and evaluates ``d my_op(a) / d a`` with
    ``tf.gradients``, which seeds the upstream gradient with ones. It then
    prints the resulting list of arrays. Because ``my_op`` registers
    ``my_op_grad`` as its gradient, the printed values come from that custom
    function rather than from the all-ones gradient of a plain identity.
    """
    a = tf.Variable(tf.constant([-5., 4., -3., 2., 1.], dtype=tf.float32))
    grad = tf.gradients(my_op(a), a)
    # The context manager closes the session even if sess.run raises.
    # The manual sess.close() would have been skipped in that case.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(grad)
    print(result)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment