Created July 7, 2018 15:46 by davidhughhenrymack (gist ad9319d23276ffa608f6826e820c7d2c)
```python
import tensorflow as tf


def dynamic_assert_shape(tensor, shape):
    """
    Check that a tensor has the shape given by a list of constants and scalar tensor values.

    This function places an assertion op into your graph that is executed at runtime.
    This is helpful because tensors often have dynamically sized dimensions (such as
    the batch size) whose values are unknown when the graph is built, so they cannot
    be checked with a static shape comparison.

    For example, measure a dimension at runtime:
    `batch_size = tf.shape(my_tensor)[0]`
    then assert that another tensor does indeed have the expected shape:
    `other_tensor = dynamic_assert_shape(other_tensor, [batch_size, 16])`

    Use this as an inline identity function (reassign its result, as above) so that
    the assertion op it creates is added to the graph and executed.

    Returns: the argument `tensor` unchanged
    """
    lhs = tf.shape(tensor)
    rhs = tf.convert_to_tensor(shape, dtype=lhs.dtype)
    # tf.assert_equal broadcasts, so compare ranks first: otherwise [4] would match shape (4, 4).
    rank_op = tf.assert_equal(tf.size(lhs), tf.size(rhs), message=f"Asserting rank of {tensor.name}")
    with tf.control_dependencies([rank_op]):
        assert_op = tf.assert_equal(lhs, rhs, message=f"Asserting shape of {tensor.name}")
    # The returned identity op depends on the assertion, so using it forces the check to run.
    with tf.control_dependencies([assert_op]):
        return tf.identity(tensor, name="dynamic_assert_shape")
```
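One subtlety worth checking: `tf.assert_equal` compares its arguments element-wise with broadcasting, so when both arguments are shape vectors, an equality check alone can accept a spec whose rank differs from the tensor's. The framework-free model below sketches both checks on concrete shapes; it is plain Python, and `exact_shape_match` and `broadcast_shape_match` are hypothetical names, not TensorFlow APIs.

```python
from typing import Sequence


def exact_shape_match(actual: Sequence[int], expected: Sequence[int]) -> bool:
    """Ranks must agree, then every dimension must agree."""
    return len(actual) == len(expected) and all(
        a == e for a, e in zip(actual, expected)
    )


def broadcast_shape_match(actual: Sequence[int], expected: Sequence[int]) -> bool:
    """Element-wise equality of two 1-D shape vectors with NumPy-style broadcasting,
    i.e. what a broadcasting equality assertion alone would accept."""
    if len(actual) == len(expected):
        return all(a == e for a, e in zip(actual, expected))
    if len(expected) == 1:  # a length-1 vector is stretched to the other's length
        return all(a == expected[0] for a in actual)
    if len(actual) == 1:
        return all(actual[0] == e for e in expected)
    raise ValueError(
        f"shape vectors of length {len(actual)} and {len(expected)} do not broadcast"
    )


# A (4, 4) tensor checked against the spec [4]:
print(broadcast_shape_match([4, 4], [4]))  # True: the rank mismatch slips through
print(exact_shape_match([4, 4], [4]))      # False
```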
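The docstring's usage example can be sketched as runnable code, assuming TensorFlow 2.x, where graph code runs inside `tf.function`. This variant uses `tf.debugging.assert_equal`, compares ranks before dimensions, and leaves `tensor.name` out of the messages because eager tensors have no usable name. `combine`, `features` and `extra` are hypothetical names chosen for illustration.

```python
import tensorflow as tf


def dynamic_assert_shape(tensor, shape):
    """TF 2.x sketch: returns `tensor` unchanged after a runtime shape check."""
    lhs = tf.shape(tensor)
    rhs = tf.convert_to_tensor(shape, dtype=lhs.dtype)
    # Rank check first, because the element-wise comparison below broadcasts.
    rank_ok = tf.debugging.assert_equal(tf.size(lhs), tf.size(rhs), message="rank mismatch")
    with tf.control_dependencies([rank_ok]):
        shape_ok = tf.debugging.assert_equal(lhs, rhs, message="shape mismatch")
    # Inside tf.function the identity runs only after the assertion has passed;
    # in eager mode the assertions above raise immediately on failure.
    with tf.control_dependencies([shape_ok]):
        return tf.identity(tensor)


# Hypothetical model step: both inputs share a batch dimension unknown until runtime.
@tf.function(input_signature=[
    tf.TensorSpec([None, 8], tf.float32),
    tf.TensorSpec([None, 16], tf.float32),
])
def combine(features, extra):
    batch_size = tf.shape(features)[0]  # measured at runtime
    extra = dynamic_assert_shape(extra, [batch_size, 16])
    return tf.concat([features, extra], axis=1)
```

Because the input signature leaves the batch dimension as `None`, the batch size is only known when the function runs, which is the situation the docstring describes; a mismatched batch fails with `tf.errors.InvalidArgumentError` at call time.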