Last active
March 15, 2021 11:42
-
-
Save iCorv/1ad195b9510c2a3918506580af5a4adf to your computer and use it in GitHub Desktop.
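A Keras Subpixel1D layer for upsampling audio and other time-series data in neural networks. It trades channels for time steps, turning a (batch, samples, channels) tensor into (batch, r * samples, channels / r).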
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Subpixel1D(tf.keras.layers.Layer):
    """Sub-pixel (pixel-shuffle) upsampling along the time axis of 1-D data.

    Rearranges a ``(batch, samples, channels)`` tensor into
    ``(batch, r * samples, channels // r)`` by moving groups of channels into
    new time steps. The layer creates no weights.

    Args:
        r: Positive integer upsampling factor. The channel dimension of the
            input must be evenly divisible by ``r``.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``r`` is not a positive integer.
    """

    def __init__(self,
                 r,
                 **kwargs):
        super().__init__(**kwargs)
        # batch_to_space needs an integral block size >= 1; reject anything
        # else here instead of failing later inside the graph.
        if int(r) != r or r < 1:
            raise ValueError(
                f'The upsampling factor r must be a positive integer. '
                f'Received r={r!r}.'
            )
        self.r = int(r)

    def build(self, input_shape):
        """Validate the input shape before the layer is first called.

        Args:
            input_shape: Shape of the input, expected to be
                ``(batch, samples, channels)``. Unknown dimensions are allowed.

        Raises:
            ValueError: If the input rank is known and is not 3, or if the
                channel count is known and not divisible by ``r``.
        """
        shape = tf.TensorShape(input_shape)
        if shape.rank is not None:
            dims = shape.as_list()
            if len(dims) != 3:
                raise ValueError(
                    f'Subpixel1D expects a rank-3 input (batch, samples, '
                    f'channels), but received shape {dims}.'
                )
            # Only check divisibility when the channel count is statically
            # known; `None % r` would raise a TypeError.
            if dims[2] is not None and dims[2] % self.r != 0:
                raise ValueError(
                    f'The number of input channels must be evenly divisible by the upsampling '
                    f'factor r. Received r={self.r}, but the input has {dims[2]} channels '
                    f'(full input shape is {dims}).'
                )
        super().build(input_shape)

    def call(self, inputs):
        """Upsample ``inputs`` by ``r`` along the time axis.

        Args:
            inputs: Tensor of shape ``(batch, samples, channels)``.

        Returns:
            Tensor of shape ``(batch, r * samples, channels // r)``.
        """
        # (batch, samples, channels) -> (channels, samples, batch), so the
        # channel axis sits where batch_to_space expects its leading axis.
        outputs = tf.transpose(inputs, [2, 1, 0])
        # batch_to_space divides the leading axis by r and multiplies the
        # following (spatial) axis by r:
        # (channels, samples, batch) -> (channels/r, r*samples, batch)
        outputs = tf.batch_to_space(outputs, [self.r], [[0, 0]])
        # (channels/r, r*samples, batch) -> (batch, r*samples, channels/r)
        outputs = tf.transpose(outputs, [2, 1, 0])
        return outputs

    def compute_output_shape(self, input_shape):
        """Return the output shape ``(batch, r * samples, channels // r)``.

        Unknown (``None``) dimensions stay ``None`` instead of raising, so
        variable-length inputs are supported.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        batch, samples, channels = input_shape[0], input_shape[1], input_shape[2]
        return (
            batch,
            None if samples is None else samples * self.r,
            None if channels is None else channels // self.r,
        )

    def get_config(self):
        """Return the base layer config extended with ``r`` for serialization."""
        return {**super().get_config(), 'r': self.r}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment