Skip to content

Instantly share code, notes, and snippets.

@soravux
Created March 3, 2017 20:16
Show Gist options
  • Save soravux/791206a72e482f506008a80c73614af5 to your computer and use it in GitHub Desktop.
Save soravux/791206a72e482f506008a80c73614af5 to your computer and use it in GitHub Desktop.
ResizeImage layer
import tensorflow as tf
from tensorflow.keras.layers import Layer
# To use: model.add(ResizeImage(<target (height, width)>))
# e.g. to resize back to the spatial size of the first layer's output:
#     model.add(ResizeImage(model.layers[0].output_shape[1:3]))
class ResizeImage(Layer):
    """Keras layer that resizes the spatial dimensions of a batch of images.

    The layer has no weights; it wraps TensorFlow's image-resize op so it can
    be inserted into a model (e.g. ``model.add(ResizeImage(...))``).

    Args:
        output_dim: Target spatial size as a ``(height, width)`` pair, e.g.
            ``model.layers[0].output_shape[1:3]``.
        **kwargs: Forwarded unchanged to the ``Layer`` constructor.

    NOTE(review): the output-shape methods treat input dim 0 as batch and
    dim 3 as channels, i.e. they assume channels-last (NHWC) input — confirm
    before using with channels-first models.
    """

    def __init__(self, output_dim, **kwargs):
        # Initialise the base layer first so attribute tracking (tf.keras) is
        # set up before any attributes are assigned.
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # Resizing has no weights to create; just let the base class mark
        # the layer as built.
        super().build(input_shape)

    def call(self, x, mask=None):
        """Resize ``x`` to ``self.output_dim`` using TensorFlow's resize op."""
        # TF 1.x provides tf.image.resize_images; TF 2.x removed it in favour
        # of tf.image.resize. Use the original op when present so TF 1.x
        # results are unchanged, and fall back so the layer works on TF 2.x.
        resize = getattr(tf.image, "resize_images", None)
        if resize is None:
            resize = tf.image.resize
        return resize(x, self.output_dim)

    def compute_output_shape(self, input_shape):
        """Return the output shape (hook name used by Keras 2 / tf.keras)."""
        return self.get_output_shape_for(input_shape)

    def get_output_shape_for(self, input_shape):
        """Return the output shape (hook name used by Keras 1).

        Batch size (dim 0) and the last dim (dim 3) are passed through; the
        two middle dims are replaced by ``output_dim``.
        """
        return (
            input_shape[0],
            self.output_dim[0],
            self.output_dim[1],
            input_shape[3],
        )

    def get_config(self):
        """Return the layer config, including ``output_dim``.

        This lets a saved model rebuild the layer with its target size.
        """
        config = super().get_config()
        config["output_dim"] = self.output_dim
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment