# An implementation of "Machine Learning on Sequential Data Using a Recurrent Weighted Average" using PyTorch
# https://arxiv.org/pdf/1703.01253.pdf
#
#
# This is an RNN (recurrent neural network) variant that keeps a weighted average of the values seen
# so far, rather than a separately updated running state.
#
# Check the test code at the bottom for an example of usage, where you can compare its performance
# against LSTM and GRU on a classification task from the paper. In that comparison it handily beats
# both the LSTM and the GRU :)
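# The weighted-average idea can be sketched in plain Python (scalars, no PyTorch). This is a minimal
# sketch, assuming the recurrence as read from the paper: z_t = u(x_t) * tanh(g(x_t, h_{t-1})),
# a_t = a(x_t, h_{t-1}), h_t = tanh(sum_i z_i exp(a_i) / sum_i exp(a_i)), with h_0 = tanh(s).
# Only a numerator, a denominator and the largest exponent seen so far are carried between steps;
# the max-rescaling (to keep exp() from overflowing) is a standard trick assumed here, not quoted
# from the paper. All helper names and the scalar weights in rwa_scalar_cell are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def rescaled_update(n: float, d: float, a_max: float,
                    z: float, a: float) -> Tuple[float, float, float]:
    """Add one term z * exp(a) to the running numerator n and denominator d.

    Both sums are stored relative to exp(a_max), so exp() never overflows.
    """
    new_max = max(a_max, a)
    scale = math.exp(a_max - new_max)  # 0.0 on the first step, when a_max is -inf
    w = math.exp(a - new_max)
    return n * scale + z * w, d * scale + w, new_max


def running_weighted_average(z: Sequence[float], a: Sequence[float]) -> List[float]:
    """Weighted average of z_1..z_t with weights exp(a_i), for every prefix t (O(1) per step)."""
    n, d, a_max = 0.0, 0.0, -math.inf
    out = []
    for z_t, a_t in zip(z, a):
        n, d, a_max = rescaled_update(n, d, a_max, z_t, a_t)
        out.append(n / d)
    return out


def direct_weighted_average(z: Sequence[float], a: Sequence[float]) -> List[float]:
    """Same quantity recomputed from scratch at every step (O(t) work per step)."""
    out = []
    for t in range(1, len(z) + 1):
        w = [math.exp(a_i) for a_i in a[:t]]
        out.append(sum(z_i * w_i for z_i, w_i in zip(z[:t], w)) / sum(w))
    return out


def rwa_scalar_cell(xs: Sequence[float], w_u: float = 0.5, b_u: float = 0.0,
                    w_gx: float = 1.0, w_gh: float = -0.5, b_g: float = 0.1,
                    w_ax: float = 0.3, w_ah: float = 0.7, s: float = 0.0) -> List[float]:
    """Scalar RWA cell with hypothetical fixed weights; returns h_1..h_T."""
    h = math.tanh(s)  # h_0 = tanh(s); s would be a learned initial state
    n, d, a_max = 0.0, 0.0, -math.inf
    hs = []
    for x in xs:
        u = w_u * x + b_u              # u depends on the input only
        g = w_gx * x + w_gh * h + b_g  # g and a also see h_{t-1}
        a = w_ax * x + w_ah * h        # no bias: a constant shift in a cancels in n / d
        z = u * math.tanh(g)
        n, d, a_max = rescaled_update(n, d, a_max, z, a)
        h = math.tanh(n / d)
        hs.append(h)
    return hs


if __name__ == "__main__":
    z = [0.5, -1.0, 2.0, 0.25]
    a = [0.1, -0.3, 1.2, 0.0]
    print(running_weighted_average(z, a))
    print(direct_weighted_average(z, a))
    # exp(1000) would overflow, but the rescaled version stays finite.
    print(running_weighted_average([1.0, 3.0], [1000.0, 1000.0]))  # [1.0, 2.0]
    print(rwa_scalar_cell([1.0, -2.0, 0.5]))
```

# Carrying (n, d, a_max) instead of re-summing the whole history is what keeps each step O(1);
# a PyTorch version would apply the same update element-wise to tensors, with learned weight
# matrices in place of the scalar weights used here.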