How to retrieve the cell/hidden state of an LSTM layer during training

Valentin Steininger on 4 Jul 2019
Answered: Yildirim Kocoglu on 10 Jan 2021
Hi everyone,
as the title says, I'm trying to extract the cell and hidden states from an LSTM layer after training. Unfortunately, I haven't found a solution yet.
Does anyone know how that works, or whether it is even possible?
Thanks for any advice!

Answers (4)

Da-Ting Lin on 11 Feb 2020
I also have this question. Hopefully it may be included in an upcoming release?

Haoyuan Ma on 16 Mar 2020
I have this question too.
I tried many times before finding this page.

Giuseppe Dell'Aversana on 16 Apr 2020
I also have this question. Maybe someone has the answer now?

Yildirim Kocoglu on 10 Jan 2021
It's a little late, but I had the same question and came across this: https://www.mathworks.com/help/ident/ug/use-lstm-for-linear-system-identification.html
I haven't tried it yet, but please read it carefully, as it may help.
Read the section "Set Network Initial State".
It says: "As the network performs estimation using a step input from 0 to 1, the states of the LSTM network (cell and hidden states of the LSTM layers) drift toward the correct initial condition. To visualize this, extract the cell and hidden state of the network at every time step using the predictAndUpdateState function."
Here is the code from that documentation page, which you can modify to fit your needs:
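% Assumes time, stepSignal and the trained fourthOrderNet from the linked example
stepMarker = time <= 2;                         % time steps to simulate
numSteps = sum(stepMarker);
numHiddenUnits = fourthOrderNet.Layers(2).NumHiddenUnits;  % 200 LSTM units in the example
yhat = zeros(numSteps,1);
hiddenState = zeros(numSteps,numHiddenUnits);
cellState = zeros(numSteps,numHiddenUnits);
for ntime = 1:numSteps
    % Predict one time step; the returned network carries the updated LSTM state
    [fourthOrderNet,yhat(ntime)] = predictAndUpdateState(fourthOrderNet,stepSignal(ntime)');
    % Layer 2 is the lstmLayer; store its states after this time step
    hiddenState(ntime,:) = fourthOrderNet.Layers(2).HiddenState';
    cellState(ntime,:) = fourthOrderNet.Layers(2).CellState';
end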
If you have multiple sequences (or batches), you can loop over them and predict with your trained network one sequence at a time (for example, for i = 1:numSequences). If your trained network is saved as net, calling net = resetState(net) at the start of each iteration resets the states to their initial values, which are zeros unless you specified them beforehand. These should be the same initial states the network started from during training, so for each sequence you should be able to read the hidden and cell states at every time step, as in the code above.
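A minimal sketch of that loop, assuming Deep Learning Toolbox and a release where the network returned by predictAndUpdateState exposes the updated states through the LSTM layer's HiddenState and CellState properties, as the linked example does (I have not checked which release first supports this). The tiny network, numHiddenUnits = 8, and the random XTrain/YTrain/XTest data are hypothetical placeholders so the script runs on its own; replace net and XTest with your own trained network and sequences, and adjust Layers(2) if your LSTM layer sits elsewhere.

```matlab
% Record LSTM hidden/cell states at every time step for several sequences,
% resetting the network state before each sequence.
% Placeholder network and data (hypothetical) so the sketch is self-contained.
rng(0);
numHiddenUnits = 8;                              % hypothetical size
layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(1)
    regressionLayer];
XTrain = {rand(1,30); rand(1,30)};               % placeholder training sequences
YTrain = {rand(1,30); rand(1,30)};
opts = trainingOptions('adam','MaxEpochs',2,'Verbose',false);
net = trainNetwork(XTrain,YTrain,layers,opts);

XTest = {rand(1,10); rand(1,15); rand(1,12)};    % placeholder test sequences
numSequences = numel(XTest);
hiddenStates = cell(numSequences,1);             % hiddenStates{i}: numSteps-by-numHiddenUnits
cellStates = cell(numSequences,1);

for i = 1:numSequences
    net = resetState(net);                       % start from the initial state
    X = XTest{i};
    numSteps = size(X,2);
    hiddenStates{i} = zeros(numSteps,numHiddenUnits);
    cellStates{i} = zeros(numSteps,numHiddenUnits);
    for t = 1:numSteps
        % Feed one time step; the returned net holds the updated state
        net = predictAndUpdateState(net,X(:,t));
        hiddenStates{i}(t,:) = net.Layers(2).HiddenState';  % layer 2 is the lstmLayer
        cellStates{i}(t,:) = net.Layers(2).CellState';
    end
end
```

Storing one matrix per sequence in a cell array keeps sequences of different lengths separate, which a single preallocated matrix would not.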
I personally needed to extract the final states to continue the prediction because I'm working on a forecasting problem.
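For the forecasting case, a minimal sketch under the same assumptions (Deep Learning Toolbox, LSTM layer at Layers(2), updated states readable from HiddenState and CellState): run the observed data through the network, read the final states, then keep predicting by feeding each prediction back in. The sine series, numHiddenUnits = 8, and numAhead = 10 are hypothetical placeholders, and the closed-loop step is one common way to continue a prediction, not necessarily what the answer's author used.

```matlab
% Warm up the LSTM state on observed data, read the final states, then
% continue the prediction in closed loop. Placeholder network/data (hypothetical).
rng(0);
numHiddenUnits = 8;                              % hypothetical size
layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(1)
    regressionLayer];
series = sin(0.2*(1:60));                        % placeholder time series
XTrain = {series(1:end-1)};                      % learn to predict the next value
YTrain = {series(2:end)};
opts = trainingOptions('adam','MaxEpochs',5,'Verbose',false);
net = trainNetwork(XTrain,YTrain,layers,opts);

% Run the whole observed sequence to bring the state up to date
net = resetState(net);
[net,YObserved] = predictAndUpdateState(net,series);
finalHidden = net.Layers(2).HiddenState;         % final hidden state
finalCell = net.Layers(2).CellState;             % final cell state

% Continue from those final states: feed each prediction back as the next input
numAhead = 10;
forecast = zeros(1,numAhead);
forecast(1) = YObserved(end);                    % prediction for the step after the data
for k = 2:numAhead
    [net,forecast(k)] = predictAndUpdateState(net,forecast(k-1));
end
```

The loop never has to pass finalHidden and finalCell explicitly, because the network returned by predictAndUpdateState already carries them; reading them out is only needed if you want to inspect or store them.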


Release: R2018b
