Last active
December 24, 2017 15:31
A proposal for adding type-based symbolic shapes to Keras
# This is a very rough early draft of something I think would speed up my coding
# in Keras. I spend a fair amount of time reading the source code to work out
# how a given method treats different dimensions, which largely seems to be by
# convention and is only semi-documented. As someone new to these libraries, it
# would help me a lot to make this explicit.
# The rough idea has three parts:
# - Have "nanotypes" representing commonly used dimensions, e.g. batch size
# - Allow Dimensions, and Shapes (a list of dimensions), to be type-enforced
# - Let the programmer specify how operations are to be performed in terms of dimension mapping
# I'm new to Python and Keras, and am still thinking through the practical implementation of this,
# so apologies for all my gross mistakes!
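# A minimal sketch of the first two ideas above in plain Python, to make them concrete.
# Everything here (Dim, NewDim, Shape, Shape.axis, Shape.check, the size= argument) is
# hypothetical and not part of Keras. It only checks a concrete shape tuple at run time,
# which is an assumption about how enforcement might work, not a settled design.

```python
class Dim:
    """A named axis such as BatchSize. Hypothetical; not a Keras class."""

    def __init__(self, name, size=None):
        self.name = name
        self.size = size  # None means "any size"

    def __repr__(self):
        return self.name


def NewDim(name, size=None):
    """Create a fresh named dimension."""
    return Dim(name, size)


class Shape:
    """An ordered list of named dimensions."""

    def __init__(self, *dims):
        self.dims = tuple(dims)

    def __class_getitem__(cls, dims):
        # Lets us write Shape[BatchSize, SequenceLength] as in the draft.
        if not isinstance(dims, tuple):
            dims = (dims,)
        return cls(*dims)

    def axis(self, dim):
        """Return the positional index of a named dimension."""
        return self.dims.index(dim)

    def check(self, concrete):
        """Raise TypeError unless the tuple `concrete` fits this shape."""
        if len(concrete) != len(self.dims):
            raise TypeError("expected rank %d %r, got %r"
                            % (len(self.dims), self.dims, concrete))
        for dim, size in zip(self.dims, concrete):
            if dim.size is not None and dim.size != size:
                raise TypeError("%s must be %d, got %d" % (dim, dim.size, size))
        return concrete


BatchSize = NewDim('BatchSize')
SequenceLength = NewDim('SequenceLength')
WordLength = NewDim('WordLength', size=8)  # fixed size, chosen only for illustration
RNNInputShape = Shape[BatchSize, SequenceLength, WordLength]

RNNInputShape.check((32, 10, 8))            # fits: rank 3, last axis is 8
print(RNNInputShape.axis(SequenceLength))   # 1
```

# With names attached to axes, an operation can ask for "the SequenceLength axis"
# via Shape.axis instead of relying on a positional convention.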
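# Proposed modules: keras.dimensions and keras.shapes do not exist in Keras today
from keras.dimensions import BatchSize, Dim, NewDim
from keras.shapes import Shape, NewShape, MinShape
from keras.layers import Input, SimpleRNN, Reshape
import keras.backend as K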
# For some imaginary RNN network
SequenceLength = NewDim('SequenceLength')
WordLength = NewDim('WordLength')
# HalfWord is used below but was never declared; assumed to be half of WordLength
HalfWord = NewDim('HalfWord')
RNNInputShape = Shape[BatchSize, SequenceLength, WordLength]
data = Input<RNNInputShape>(name='input')
def mySpecialRNN<Output_Width:Dim>(input: RNNInputShape):
    rnn = SimpleRNN<Output_Width>()
    n = rnn(input)
    # ...
    n = Reshape(Shape[SequenceLength, HalfWord, 2])(n)
    return n
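# One way the shape annotation on a function argument could be enforced today is a
# decorator that reads the annotations at call time. This is only a sketch: Shape,
# Tensor, enforce_shapes and my_special_rnn are hypothetical stand-ins, and axis names
# are plain strings to keep the example self-contained.

```python
import functools
import inspect


class Shape:
    """Hypothetical symbolic shape: an ordered tuple of axis names."""

    def __init__(self, *names):
        self.names = names


class Tensor:
    """Stand-in for a backend tensor that carries its axis names."""

    def __init__(self, dims):
        self.dims = tuple(dims)


def enforce_shapes(fn):
    """Check every argument annotated with a Shape against the value's `.dims`."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = sig.parameters[name].annotation
            # Arguments without a Shape annotation are passed through unchecked.
            if isinstance(expected, Shape) and tuple(value.dims) != expected.names:
                raise TypeError("%s: expected %r, got %r"
                                % (name, expected.names, value.dims))
        return fn(*args, **kwargs)

    return wrapper


RNNInputShape = Shape("BatchSize", "SequenceLength", "WordLength")


@enforce_shapes
def my_special_rnn(input: RNNInputShape):
    return input


ok = my_special_rnn(Tensor(["BatchSize", "SequenceLength", "WordLength"]))
```

# Passing a tensor whose axes are in a different order, or that is missing an axis,
# raises TypeError before the function body runs.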
def combine_some_tensors(x: MinShape[BatchSize, HalfWord], y: MinShape[BatchSize, HalfWord, _], z):
    # Reshape, with type checking so we know it's possible
    y = K.reshape[BatchSize, _, HalfWord](y)
    # ... later, we have some more tensors to play with and are defining ...
    # I'd like to multiply along just the HalfWord axis and scan the BatchSize axis
    r = K.batch_dot[[BatchSize], HalfWord](x, y)
    # Find the mean along the batch axis please
    r = K.mean[BatchSize](r)
    return r
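# The "multiply along HalfWord, scan BatchSize, then average over BatchSize" steps can be
# approximated with NumPy by translating axis names into an einsum specification.
# named_batch_dot and named_mean are hypothetical helpers, not Keras or NumPy functions,
# and the axis name "Other" stands in for the `_` wildcard above.

```python
import numpy as np


def named_batch_dot(x, x_dims, y, y_dims, batch, contract):
    """Sum products over `contract`, keeping `batch` aligned between x and y."""
    for dims in (x_dims, y_dims):
        if batch not in dims or contract not in dims:
            raise TypeError("both operands need %r and %r" % (batch, contract))
    # Give every distinct axis name its own einsum letter.
    names = list(dict.fromkeys(tuple(x_dims) + tuple(y_dims)))
    letter = {name: chr(ord('a') + i) for i, name in enumerate(names)}
    # The contracted axis disappears; every other axis is kept in the result.
    out_dims = tuple(name for name in names if name != contract)
    spec = "%s,%s->%s" % (
        "".join(letter[n] for n in x_dims),
        "".join(letter[n] for n in y_dims),
        "".join(letter[n] for n in out_dims),
    )
    return np.einsum(spec, x, y), out_dims


def named_mean(x, dims, over):
    """Average over the axis called `over` and drop it from the names."""
    axis = dims.index(over)
    out_dims = tuple(d for d in dims if d != over)
    return x.mean(axis=axis), out_dims


x = np.ones((2, 3))      # (BatchSize, HalfWord)
y = np.ones((2, 4, 3))   # (BatchSize, Other, HalfWord)

r, r_dims = named_batch_dot(
    x, ("BatchSize", "HalfWord"),
    y, ("BatchSize", "Other", "HalfWord"),
    batch="BatchSize", contract="HalfWord",
)
m, m_dims = named_mean(r, r_dims, over="BatchSize")

print(r_dims, r.shape)   # ('BatchSize', 'Other') (2, 4)
print(m_dims, m.shape)   # ('Other',) (4,)
```

# Returning the surviving axis names alongside the array is one possible design; it lets
# the next operation refer to axes by name without tracking positions by hand.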