While continuing my study of neural networks and deep learning, I inevitably ran into recurrent neural networks.
Recurrent neural networks (RNNs) are a kind of neural network that is usually very good at predicting sequences, because the hidden state they carry from one time step to the next acts as a memory of past inputs. If your task is to predict a sequence or a periodic signal, an RNN might be a good starting point. Plain vanilla RNNs work fine, but they have trouble “keeping in memory” events that occurred many steps back (say, more than 20), largely because gradients tend to vanish as they are propagated back through time during training. This problem was addressed by a model called the Long Short-Term Memory (LSTM) network. As far as I know, an LSTM should usually be preferred to a plain vanilla RNN when possible, as it tends to yield better results.
In this post, however, I am going to work with a plain vanilla RNN. I have two reasons for this. First, this is one of my first experiences with RNNs and I would like to get comfortable with them before going deeper; second, R provides a simple and very user-friendly package named “rnn” for working with recurrent neural networks. I am going to dive into LSTMs using MXNet and TensorFlow later.
Task description
The task is to predict a cosine wave from a noisy sine wave. The plot below shows the predictor sequence X and the sequence Y to be predicted.
X is essentially a sine wave with some normally distributed noise, while Y is a straightforward smooth cosine wave.
In other words, I expect the model to capture the 90-degree phase shift between the two waves and to discard the noise in the input.
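The relationship the model has to learn can be written explicitly. With angular frequency $\omega = 2\pi f$, the noise-free part of X and the target Y differ only by a quarter-period shift; for $f = 5$ Hz the period is 0.2 s, so the shift is 0.05 s:

$$
\cos(\omega t) = \sin\left(\omega t + \frac{\pi}{2}\right) = \sin\left(\omega\left(t + \frac{1}{4f}\right)\right), \qquad \omega = 2\pi f
$$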
I chose a 5 Hz frequency for both waves, but you can experiment with other frequencies and see whether you get similar results. Be aware, though, that the higher the frequency, the more data points you need: by the Nyquist-Shannon sampling theorem, the sampling rate must be more than twice the signal frequency to avoid aliasing.
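A quick way to check the sampling requirement is to count how many samples fall in one period and to compare the sampling rate with twice the signal frequency. This sketch is in Python rather than the R used for the experiment, and the sampling interval `dt = 0.005` s is a hypothetical value, since the post does not state one:

```python
def samples_per_period(freq_hz: float, dt: float) -> float:
    """Number of samples that fall in one period of a wave of frequency freq_hz."""
    return 1.0 / (freq_hz * dt)


def satisfies_nyquist(freq_hz: float, dt: float) -> bool:
    """True if the sampling rate 1/dt is strictly greater than twice freq_hz."""
    return 1.0 / dt > 2.0 * freq_hz


dt = 0.005  # hypothetical sampling interval in seconds (200 samples per second)
for f in (5, 10, 50, 150):
    print(f, samples_per_period(f, dt), satisfies_nyquist(f, dt))
```

With this assumed interval, 5 Hz gives 40 samples per period and 10 Hz gives 20, both comfortably above the Nyquist limit. In practice you want far more than the bare minimum of two samples per period if the network is to see the shape of the wave.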
Preprocessing
The artificial dataset I created for this task consists of 10 sequences of 40 observations each. The X matrix contains 10 sequences of a noisy sine wave, while the Y matrix contains the corresponding 10 sequences of a clean cosine wave.
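A dataset of this shape can be generated along the following lines. This is a Python/NumPy sketch, not the author's R gist: the sampling interval, the noise standard deviation and the random seed are hypothetical values chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(10)   # arbitrary seed for reproducibility
f = 5.0                           # wave frequency in Hz, as in the post
w = 2 * np.pi * f                 # angular frequency
dt = 0.005                        # hypothetical sampling interval in seconds
n_seq, seq_len = 10, 40           # 10 sequences of 40 observations
noise_sd = 0.25                   # hypothetical standard deviation of the noise

t = dt * np.arange(1, n_seq * seq_len + 1)                   # 400 time stamps
x = np.sin(w * t) + rng.normal(0.0, noise_sd, size=t.shape)  # noisy sine (input)
y = np.cos(w * t)                                            # clean cosine (target)

# Row i holds the i-th consecutive chunk of 40 observations.
X = x.reshape(n_seq, seq_len)
Y = y.reshape(n_seq, seq_len)
```

With these assumed values each row spans 40 × 0.005 = 0.2 s, exactly one period of the 5 Hz wave. NumPy's default row-major reshape keeps each row a contiguous stretch of time; in R, `matrix()` fills column-wise by default, so the equivalent code there needs `byrow = TRUE` or a transpose.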
Before fitting the model, I rescaled all the data to the $[0, 1]$ interval. When using any neural network model with real-valued data, make sure not to skip this step: if you do, you might spend the next hour trying to figure out why the model did not converge or spat out weird results. I am not an expert, but I know from painful personal experience that this step is usually crucial; nevertheless, I still occasionally forget it and then wander around like a fool wondering why I did not get what I expected.
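Rescaling to $[0, 1]$ here means min-max normalization. A minimal Python/NumPy sketch, assuming one global minimum and maximum per matrix (X and Y scaled separately); the function names are hypothetical. Keeping `lo` and `hi` lets you map predictions back to the original units, which matters if the network's output unit is a logistic sigmoid that can only produce values between 0 and 1.

```python
import numpy as np


def minmax_scale(a):
    """Linearly map all entries of `a` into [0, 1] using the array's global min and max.

    Returns the scaled array together with (lo, hi) so predictions can be mapped back.
    """
    a = np.asarray(a, dtype=float)
    lo, hi = a.min(), a.max()
    if hi == lo:
        raise ValueError("cannot rescale a constant array")
    return (a - lo) / (hi - lo), (lo, hi)


def minmax_inverse(scaled, lo, hi):
    """Undo minmax_scale, e.g. to express predictions in the original units."""
    return np.asarray(scaled, dtype=float) * (hi - lo) + lo


demo = np.array([[-1.0, 0.0], [0.5, 1.0]])
scaled, (lo, hi) = minmax_scale(demo)
restored = minmax_inverse(scaled, lo, hi)
```

In this small example the minimum is -1 and the maximum is 1, so -1, 0, 0.5 and 1 are mapped to 0, 0.5, 0.75 and 1.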
Model
As for the model, I used 16 hidden neurons, mostly because the other configurations I tried all produced weird spikes at the peaks and valleys of the waves. This was the most notable problem I encountered with this task: the rising and falling parts of the wave are very easy to predict, while the peaks and valleys can cause trouble and come out as sudden spikes. 16 hidden units and about 1500 epochs seem to fix this problem.
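To see what a plain vanilla RNN actually computes, here is a minimal forward pass of an Elman-style RNN with 16 logistic hidden units, one input and one output, written in Python/NumPy. It is an illustrative sketch under my own assumptions (random untrained weights, logistic activations, no batching), not the internals of the R rnn package, and training by backpropagation through time over the ~1500 epochs is left out:

```python
import numpy as np


def sigmoid(z):
    """Logistic activation, mapping any real value into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def rnn_forward(x_seq, W_xh, W_hh, W_hy, b_h, b_y):
    """Run a vanilla (Elman) RNN over one sequence.

    x_seq has shape (T, n_in); the result has shape (T, n_out).
    """
    h = np.zeros(W_hh.shape[0])  # initial hidden state
    outputs = []
    for step in range(x_seq.shape[0]):
        # The new hidden state mixes the current input with the previous hidden state.
        h = sigmoid(x_seq[step] @ W_xh + h @ W_hh + b_h)
        outputs.append(sigmoid(h @ W_hy + b_y))
    return np.array(outputs)


rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 1, 16, 1
W_xh = rng.normal(0.0, 0.5, (n_in, n_hidden))
W_hh = rng.normal(0.0, 0.5, (n_hidden, n_hidden))
W_hy = rng.normal(0.0, 0.5, (n_hidden, n_out))
b_h = np.zeros(n_hidden)
b_y = np.zeros(n_out)

x_seq = rng.uniform(0.0, 1.0, (40, n_in))  # one rescaled input sequence of 40 steps
y_hat = rnn_forward(x_seq, W_xh, W_hh, W_hy, b_h, b_y)
```

The recurrence $h_t = \sigma(x_t W_{xh} + h_{t-1} W_{hh} + b_h)$ is what lets the output at step $t$ depend on earlier inputs. The model needs that: a single sine value does not tell you whether the wave is rising or falling, so the cosine can only be recovered from the recent history. The hidden size is the knob I tuned to 16.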
Results
Below are the results I obtained after some experimenting.
This is the full prediction for the entire predictor matrix X:
And this is the prediction for the test set:
I would say they look pretty good. I encourage you to play with this and look for the limits of the model. For instance, I tried doubling the frequency of the cosine wave to 10 Hz, and the predictions still look pretty good. Below you can see the X sequence (unchanged) and the Y sequence with doubled frequency. The last plot shows the predictions on the test set vs. the real values.
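The plots are the real verdict here, but a single number makes it easier to compare runs (for instance the 5 Hz and 10 Hz targets, or different hidden sizes). The post does not say which sequences form the test set, so this Python/NumPy sketch assumes, hypothetically, that the last two of the ten are held out, and uses mean squared error as the metric; the toy arrays stand in for the real data and predictions.

```python
import numpy as np


def split_rows(X, Y, n_test):
    """Hold out the last `n_test` sequences (rows) of X and Y as the test set."""
    if not 0 < n_test < len(X):
        raise ValueError("n_test must be between 1 and the number of sequences minus 1")
    return X[:-n_test], Y[:-n_test], X[-n_test:], Y[-n_test:]


def mse(y_true, y_pred):
    """Mean squared error between two arrays of the same shape."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))


# Toy stand-ins for the 10 x 40 matrices; replace with real data and model predictions.
X = np.zeros((10, 40))
Y = np.ones((10, 40))
X_train, Y_train, X_test, Y_test = split_rows(X, Y, n_test=2)
Y_pred = np.full_like(Y_test, 0.9)  # pretend predictions, off by 0.1 everywhere
test_error = mse(Y_test, Y_pred)
```

Compute the error on the rescaled values or on the original scale, but use the same choice for every run you compare.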
The code I used for this simple experiment is shown in the gist below. To get the plots for the doubled-frequency example, just set $f = 10$.
Thank you for reading this post; I hope you found it interesting and useful. If you have any questions, please leave a comment.
I'm curious: how would you implement the same code using MXNet?
Hi! As far as I know, MXNet provides an LSTM implementation. You can find some examples in the official documentation here: http://mxnet.io/tutorials/nlp/rnn.html
Thanks, I've already seen it, but I can't figure out exactly how to do time-series forecasting with MXNet and LSTM…
Maybe in the future I'll try it and write a post about it.
To be honest, though, in most of the tests I ran with time series, a plain RNN or a simple deep neural network seemed to work fine out of the box, without much fine-tuning. If you do not have a specific reason to use an LSTM, try these first and see whether you get good results.
Hi Kanime, take a look at https://stats.stackexchange.com/: there's a community of people there who may be able to help you with your questions. The people over there are very nice, helpful and skilled. If you look through the existing questions, you might even find the very one you were about to ask (that happens to me very often).
Take a look at http://www.financial-hacker.com/build-better-strategies-part-5-developing-a-machine-learning-system/
Hello!
I'm trying to use this package, which seems to have everything I need for time-series prediction, but I'm stuck at the very beginning: what's the correct way of presenting data to the rnn?
My dataset is made of:
Y (outcome variable)
Week (1-52, week of the year)
X1-X3 (independent variables)
So I have Y = f(week, X1,X2,X3)
According to the docs,
"Y
array of output values, dim 1: samples (must be equal to dim 1 of X), dim 2: time (must be equal to dim 2 of X), dim 3: variables (could be 1 or more, if a matrix, will be coerce to array)
X
array of input values, dim 1: samples, dim 2: time, dim 3: variables (could be 1 or more, if a matrix, will be coerce to array)"
But I can't quite figure it out!
Thanks in advance for your time!
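For anyone stuck on the same question, here is one way to read the layout quoted from the docs, sketched in Python/NumPy with made-up data. It assumes, hypothetically, that each year of weekly data is one sample, so dim 1 is the year, dim 2 is the week, and dim 3 holds the variables (Week, X1, X2, X3 for the input; Y alone for the output):

```python
import numpy as np

n_years, n_weeks = 3, 52              # hypothetical: 3 yearly sequences of 52 weeks
rng = np.random.default_rng(1)

# Hypothetical long-format table: one row per (year, week), sorted by year then week.
week = np.tile(np.arange(1, n_weeks + 1), n_years)
x1, x2, x3 = rng.normal(size=(3, n_years * n_weeks))
y = rng.normal(size=n_years * n_weeks)

# Stack the predictors as columns: shape (n_rows, n_variables).
inputs = np.column_stack([week, x1, x2, x3])

# Reshape into (samples, time, variables); row-major order keeps each year's weeks together.
X = inputs.reshape(n_years, n_weeks, inputs.shape[1])
Y = y.reshape(n_years, n_weeks, 1)
```

The requirement that dims 1 and 2 of X and Y match simply says that every input time step needs a corresponding output. If you build the same thing in R, remember that `array()` fills in column-major order, so filling `dim = c(3, 52, 4)` directly from a row-per-week table would not give this layout; filling `c(52, 3, 4)` and then permuting with `aperm(x, c(2, 1, 3))` does.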
Hi, I'm afraid I can't really help you with your specific problem, but you can ask on Stack Exchange's Cross Validated site, where there are plenty of knowledgeable people ready to answer your questions.
I have the same problem. Have you figured out how to solve it?