Created
October 4, 2017 08:19
-
-
Save InnerPeace-Wu/9a18ce69242b06877daabd7aae9743dd to your computer and use it in GitHub Desktop.
identity_initializer: a TensorFlow initializer that initializes a 2-D matrix to the identity.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import math | |
from tensorflow.python.framework import constant_op | |
from tensorflow.python.framework import dtypes | |
from tensorflow.python.ops import array_ops | |
from tensorflow.python.ops import linalg_ops | |
from tensorflow.python.ops.init_ops import Initializer | |
class identity_initializer(Initializer):
    """Initializer that generates identity matrices.

    The generated value has ones on the main diagonal and zeros elsewhere, as
    produced by ``linalg_ops.eye``. For a non-square ``[rows, cols]`` shape the
    ones sit at positions ``(i, i)`` for ``i < min(rows, cols)``.

    Args:
      dtype: Default data type of the generated tensor, used whenever
        ``__call__`` is not given an explicit ``dtype``.
    """

    def __init__(self, dtype=dtypes.float32):
        self.dtype = dtypes.as_dtype(dtype)

    def __call__(self, shape, dtype=None, partition_info=None):
        """Return an identity tensor of the requested shape.

        Args:
          shape: Sequence of ints with at least two entries. The last two are
            the matrix rows and columns; any leading entries are treated as
            batch dimensions, so every trailing 2-D slice is an identity matrix.
          dtype: Optional data type; defaults to the dtype given at construction.
          partition_info: Ignored; accepted for compatibility with the
            ``Initializer`` call signature. NOTE(review): because it is ignored,
            each partition of a partitioned variable would get its own identity
            slice rather than its part of a global identity — confirm if
            partitioning is ever used with this initializer.

        Returns:
          A tensor of the given shape and dtype.

        Raises:
          ValueError: If ``shape`` has fewer than two dimensions.
        """
        dtype = self.dtype if dtype is None else dtypes.as_dtype(dtype)
        shape = list(shape)
        if len(shape) < 2:
            raise ValueError(
                "identity_initializer requires a shape with at least 2 "
                "dimensions, got %r" % (shape,))
        num_rows, num_columns = shape[-2], shape[-1]
        batch_shape = shape[:-2]
        if batch_shape:
            # Leading dimensions become a batch of identity matrices instead of
            # being silently mistaken for rows/columns.
            return linalg_ops.eye(num_rows, num_columns,
                                  batch_shape=batch_shape, dtype=dtype)
        return linalg_ops.eye(num_rows, num_columns, dtype=dtype)

    def get_config(self):
        """Return the constructor arguments needed to recreate this initializer."""
        return {"dtype": self.dtype.name}
def identity(shape=(3, 4)):
    """Build a variable with ``identity_initializer`` and print its value.

    Runs in a freshly reset default graph using TF1-style graph mode
    (``tf.get_variable`` / ``tf.Session``).

    Args:
      shape: Shape of the demo variable; defaults to ``(3, 4)``.

    Returns:
      The value returned by ``sess.run`` for the variable (also printed).
    """
    # ``tf`` is not imported at module level in this file; import it here so
    # the demo does not fail with a NameError.
    import tensorflow as tf

    tf.reset_default_graph()
    a = tf.get_variable('a', list(shape), tf.float32, identity_initializer())
    # Equivalent alternative without a custom initializer:
    #   a = tf.Variable(tf.eye(3, num_columns=4), dtype=tf.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        value = sess.run(a)
    print(value)
    return value
# Run the demo only when executed as a script (``__name__``, not ``__name``,
# which would raise NameError).
if __name__ == '__main__':
    identity()

# Expected output:
# [[ 1. 0. 0. 0.]
#  [ 0. 1. 0. 0.]
#  [ 0. 0. 1. 0.]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment