# Created September 20, 2018 23:11
# GitHub gist thomelane/a8507f13b298402c4f0ae8b3b48e1396 (can be saved locally or opened in GitHub Desktop).
# [Convolutions on Medium] Used in Medium blog post series #python #convolutions
# NOTE: GitHub warned that this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. To review it, open the file in
# an editor that reveals hidden Unicode characters (see GitHub's documentation on
# bidirectional Unicode characters).
def apply_conv(data, kernel, conv):
    """Run ``data`` through ``conv`` after loading ``kernel`` as its weight.

    Both ``data`` and ``kernel`` may be passed without their leading axes.
    Singleton axes are prepended to each until it has as many dimensions as
    ``conv.weight.shape``. The layer is then initialized, its weight is
    overwritten with ``kernel``, and the layer is applied to ``data``.

    Args:
        data (NDArray): input data; missing leading (batch/channel) axes are
            added automatically.
        kernel (NDArray): convolution's kernel parameters; missing leading
            (channel) axes are added automatically.
        conv (Block): convolutional layer. A class name ending in
            ``"Transpose"`` marks it as a transposed convolution.

    Returns:
        NDArray: output data (after applying convolution).
    """
    # Both the input and the kernel are padded up to the weight's rank.
    target_ndim = len(conv.weight.shape)

    # Prepend singleton axes (batch, channels) until data reaches that rank.
    while data.ndim < target_ndim:
        data = data.expand_dims(0)
    # Prepend singleton axes (output/input channels) until kernel reaches it.
    while kernel.ndim < target_ndim:
        kernel = kernel.expand_dims(0)

    # The kernel's in_channels axis is 1 for regular convolutions and 0 for
    # transposed ones; transposed layers are recognised by class name.
    in_channel_idx = 0 if type(conv).__name__.endswith("Transpose") else 1

    # Record the kernel's in_channels on the layer.
    # NOTE(review): this writes a private attribute; confirm whether the
    # weight Parameter's shape actually depends on it in the Gluon version used.
    conv._in_channels = kernel.shape[in_channel_idx]
    # Initialize the layer's parameters, then replace the weight with kernel.
    conv.initialize()
    conv.weight.set_data(kernel)
    return conv(data)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.