
What is the order of weights in an LSTM layer with multiple LSTM neurons?

The documentation says that the weights are concatenated in the order W = [Wi Wf Wg Wo], but it does not discuss a layer with multiple LSTM neurons.
Take the case of 2 LSTM neurons.
Are the weights concatenated as:
[Wi_neuron1 Wf_neuron1 Wg_neuron1 Wo_neuron1 Wi_neuron2 Wf_neuron2 Wg_neuron2 Wo_neuron2]
OR
[Wi_neuron1 Wi_neuron2 Wf_neuron1 Wf_neuron2 Wg_neuron1 Wg_neuron2 Wo_neuron1 Wo_neuron2]
In other words, are the gates grouped together for all neurons, or do the 4 gates repeat for all neurons?
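To make the two candidate layouts concrete, the following plain MATLAB sketch labels each of the 4*H rows of the weight matrix under both orderings for H = 2 hidden units. It does not query a network; the variable names layoutA, layoutB and gateNames are illustrative only. Under layout A the row of gate g for unit k is (k-1)*4 + g, and under layout B it is (g-1)*H + k.

```matlab
% Label the rows of a 4*H-row LSTM weight matrix under the two candidate
% layouts from the question. H = 2 matches the two-neuron example.
H = 2;                                   % number of hidden units (neurons)
gateNames = {'Wi', 'Wf', 'Wg', 'Wo'};    % input, forget, cell candidate, output

layoutA = cell(4*H, 1);   % layout A: the 4 gates repeat for every neuron
layoutB = cell(4*H, 1);   % layout B: each gate grouped across all neurons
for k = 1:H
    for g = 1:4
        label = sprintf('%s_neuron%d', gateNames{g}, k);
        layoutA{(k-1)*4 + g} = label;    % row index under layout A
        layoutB{(g-1)*H + k} = label;    % row index under layout B
    end
end

% strjoin expects a 1-by-N cell array, hence the transpose
disp(strjoin(layoutA', ' '));
disp(strjoin(layoutB', ' '));
```

The two printed lines reproduce the two bracketed orderings in the question, in the same order.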

Answers (2)

Sanjana on 21 Aug 2023
Hi,
I understand that you are having trouble understanding the order of the weights in an LSTM layer with multiple units. As per the documentation, each LSTM unit is associated with an “Input gate”, a “Forget gate”, a “Cell candidate” and an “Output gate”.
The learnable weights described in the documentation refer to the LSTM layer as a whole, including all of its units, not to an individual LSTM unit. By checking the “Learnables” of the following example “dlnetwork” object, which contains an LSTM layer, we can interpret the learnable weights as explained in the documentation. The learnable weights are obtained by concatenating the weight matrices associated with the “Input gate”, “Forget gate”, “Cell candidate” and “Output gate” of the individual LSTM cells in the LSTM layer.
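% Example network whose LSTM layer has 100 hidden units
inputSize = 1;
embeddingDimension = 100;
numWords = 2920;
numClasses = numWords + 1;
layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)  % requires Text Analytics Toolbox
    lstmLayer(100)          % InputWeights: 400-by-100 (4 gates x 100 units, input size 100)
    dropoutLayer(0.2)
    fullyConnectedLayer(numClasses)
    softmaxLayer];
lgraph = layerGraph(layers);
dlnet = dlnetwork(lgraph);  % initializes the learnable parameters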
The learnable weights of this network can then be inspected in the “dlnet.Learnables” table of the “dlnetwork” object.
Please refer to the following documentation for more information.
Hope this helps.
Regards,
Sanjana.

findOut on 21 Aug 2023
I did some experiments to convince myself that the weights are concatenated in the following manner: if there are 10 LSTM neurons, the first 10 rows are the input gate weights, the next 10 rows are the forget gate weights, the next 10 rows are the cell candidate weights, and the final 10 rows are the output gate weights. In other words, each gate is grouped together across all neurons (the second layout in the question). I confirmed this with a MathWorks technical support person.
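Assuming this gate-grouped layout (which matches the lstmLayer documentation's description of the four gate matrices being stacked vertically in the order input gate, forget gate, cell candidate, output gate), the gate blocks and the rows belonging to one unit can be pulled out by plain row indexing. This is a sketch on a toy matrix: H, D, k and the reshape-generated W are hypothetical values, and for a real network W would be the LSTM layer's InputWeights property. The documentation describes RecurrentWeights and Bias with the same gate order, so the same row indexing should apply to them.

```matlab
% Toy stand-in for an LSTM InputWeights matrix with 4*H rows and D columns.
% Assumption: for a real initialized dlnetwork you could use something like
% W = dlnet.Layers(3).InputWeights (the layer index depends on the network).
H = 10;                       % number of hidden units (hypothetical)
D = 3;                        % input size (hypothetical)
W = reshape(1:4*H*D, 4*H, D);

% Gate-grouped layout: rows (g-1)*H+1 : g*H belong to gate g
Wi = W(1:H, :);               % input gate
Wf = W(H+1:2*H, :);           % forget gate
Wg = W(2*H+1:3*H, :);         % cell candidate
Wo = W(3*H+1:4*H, :);         % output gate

% Weights of a single hidden unit k: one row from each gate block
k = 4;
unitRows = k + (0:3)*H;       % -> [4 14 24 34]
Wunit = W(unitRows, :);       % 4-by-D: [input; forget; candidate; output] rows of unit k
```

Stacking [Wi; Wf; Wg; Wo] vertically gives back W, which is exactly the gate-grouped ordering.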

Release

R2022a
