LSTM with vector as output for multi step ahead forecasting

I would like to build an LSTM network that outputs a vector only on the last time step (OutputMode="last").
Input data: sequences with a fixed size of [4,1344] (4 features, 1344 time steps).
Output data: a vector of size [96,1] (output on the last time step only).
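A minimal sketch of the predictor layout described above, using random data. It assumes 10 observations (hypothetical; the question does not say how many there are) and stores XTrain as an observations-by-1 cell array of 4-by-1344 matrices (features by time steps), which is the cell-array format trainNetwork accepts for sequence input.

```matlab
% Hypothetical number of observations, for illustration only
numObservations = 10;
numFeatures = 4;
numTimeSteps = 1344;

% One cell per observation; each cell is numFeatures-by-numTimeSteps
XTrain = cell(numObservations,1);
for n = 1:numObservations
    XTrain{n} = rand(numFeatures,numTimeSteps);
end

size(XTrain{1})   % 4-by-1344: features by time steps
```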
My attempt so far:
numFeatures = 4;
numResponses = size(YTrain{1},1);
% Input layer
layers = sequenceInputLayer(numFeatures);
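% Variable number of LSTM layers (LSTMDepth and numHiddenUnits are defined elsewhere, not shown)
for i = 1:LSTMDepth-1
    layers = [layers;lstmLayer(numHiddenUnits,OutputMode="sequence")];
end
% The last LSTM layer outputs only the final time step
layers = [layers;lstmLayer(numHiddenUnits,OutputMode="last")];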
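% Output layers: map the last hidden state to the 96 responses (regression)
layers = [ layers
    fullyConnectedLayer(numResponses)
    regressionLayer];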
% training options
maxEpochs = 300;
miniBatchSize = 20;
options = trainingOptions("adam", ...
ExecutionEnvironment="auto", ...
MaxEpochs=maxEpochs, ...
MiniBatchSize=miniBatchSize, ...
ValidationData={XValidation,YValidation}, ...
ValidationFrequency=30, ...
InitialLearnRate=params.InitialLearnRate, ...
LearnRateDropFactor=0.2, ...
LearnRateDropPeriod=15, ...
GradientThreshold=1, ...
Shuffle="never");
% Training: XTrain and YTrain are cell arrays
net = trainNetwork(XTrain,YTrain,layers,options);
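With placeholder values plugged in, the layer stack can be built and inspected on its own (Deep Learning Toolbox required). LSTMDepth = 2 and numHiddenUnits = 100 are hypothetical, since the question does not give them, and the fullyConnectedLayer/regressionLayer output head is an assumption consistent with the regression error quoted in the answer.

```matlab
% Hypothetical hyperparameters (not given in the question)
numFeatures = 4;
numResponses = 96;
numHiddenUnits = 100;
LSTMDepth = 2;

layers = sequenceInputLayer(numFeatures);
% All but the last LSTM layer pass the full sequence to the next layer
for i = 1:LSTMDepth-1
    layers = [layers; lstmLayer(numHiddenUnits,OutputMode="sequence")]; %#ok<AGROW>
end
% The last LSTM layer emits only its final time step
layers = [layers; lstmLayer(numHiddenUnits,OutputMode="last")];
% Assumed output head: fully connected layer to 96 outputs, then regression
layers = [layers
    fullyConnectedLayer(numResponses)
    regressionLayer];

disp(numel(layers))   % 5 layers when LSTMDepth = 2
```

Stacking the "sequence" layers first and a single "last" layer at the end is what turns a 4-by-1344 input sequence into one 96-element prediction per observation.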
Can someone help me build such a network?
Thanks in advance

Accepted Answer

Katja Mogalle on 25 Apr 2023
Hello Alexander,
The network structure generally looks fine to me.
When I run your code, I get the following error. Is it the same one you're getting?
Error using trainNetwork
Invalid training data. For regression tasks, responses must be a vector, a matrix, or a 4-D array of real numeric responses. Responses must not contain NaNs.
To fix this error, you need to slightly reorganize your response data. Currently it is a NumObservations-by-1 cell array of NumResponses-by-1 matrices, but trainNetwork needs a single matrix of size NumObservations-by-NumResponses. You can do the conversion with the following line of code:
>> YTrain = cat(2,YTrain{:})';
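To see what this conversion does, here is a small sketch on random data with a hypothetical 5 observations of 96 responses each. Since YValidation is passed to ValidationData in the same cell format, it presumably needs the same conversion.

```matlab
% Hypothetical sizes: 5 observations, 96 responses each
numObservations = 5;
numResponses = 96;

% Response data in the question's format: NumObservations-by-1 cell array
% of NumResponses-by-1 column vectors
YCell = cell(numObservations,1);
for n = 1:numObservations
    YCell{n} = rand(numResponses,1);
end

% cat(2,...) places the column vectors side by side (96-by-5);
% the transpose gives one row per observation (5-by-96)
YTrain = cat(2,YCell{:})';

size(YTrain)   % 5-by-96: observations by responses
```

Row n of the result is the transposed response vector of observation n, so the observation order of XTrain is preserved.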
For more information, see the trainNetwork documentation page, in particular the section describing the supported response data formats.
I hope this helps.

