
Is it possible to share common weights and bias among different LSTM layers?

I am building a network that looks like the figure below.
There are three LSTM layers, namely LSTM_common_1, LSTM_common_2 and LSTM_common_3.
Can I restrict their weights and biases so that all of the LSTM_common_x layers share the same set of weights and biases?
[Figure: network diagram (image not available)]

Answers (1)

Conor Daly on 17 Feb 2023
One way to share weights like this is to use nested layers, that is, custom layers whose learnable parameters are themselves neural networks (dlnetwork objects). The general idea is to create a layer that applies the shared sub-network (in this case just a single LSTM layer) wherever it is needed; here, to each input channel in turn.
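For comparison, the same sharing can also be written without a custom layer by calling the functional lstm operation with one set of parameters for every branch, as you might in a model function for a custom training loop. This is a sketch of that alternative, not the nested-layer method; the parameter initialization (small random weights, zero bias and zero initial states) and all sizes are illustrative assumptions.

```matlab
% Sketch: one set of LSTM parameters reused for every branch (functional form).
numChannels = 3;          % one branch per input channel (assumption)
numHiddenUnits = 4;
numObservations = 2;
sequenceLength = 10;

% Shared parameters. Input size is 1 because each branch sees a single channel.
weights = dlarray(0.1*randn(4*numHiddenUnits, 1, "single"), "CU");
recurrentWeights = dlarray(0.1*randn(4*numHiddenUnits, numHiddenUnits, "single"), "CU");
bias = dlarray(zeros(4*numHiddenUnits, 1, "single"), "C");

% Zero initial hidden and cell states, one column per observation.
H0 = zeros(numHiddenUnits, numObservations, "single");
C0 = zeros(numHiddenUnits, numObservations, "single");

X = dlarray(randn(numChannels, numObservations, sequenceLength, "single"), "CBT");

Y = cell(1, numChannels);
for n = 1:numChannels
    % Every branch uses the SAME weights, recurrentWeights and bias.
    Y{n} = lstm(X(n,:,:), H0, C0, weights, recurrentWeights, bias);
end
```

Because there is only one copy of each parameter, gradients from all three branches accumulate into the same arrays during training.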
Here's an example for the case above:
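classdef commonLSTMLayer < nnet.layer.Layer ...
        & nnet.layer.Formattable ...
        & nnet.layer.Acceleratable
    % Applies ONE nested LSTM network (one set of weights and biases) to each
    % input channel separately and returns one output per channel.
    % Save as commonLSTMLayer.m. nnet.layer.Acceleratable requires R2022b or later.
    properties (Learnable)
        % Nested dlnetwork holding the shared LSTM parameters
        Network
    end
    methods
        function this = commonLSTMLayer(numHiddenUnits, numOutputs, args)
            arguments
                numHiddenUnits (1,1) {mustBePositive, mustBeInteger}
                numOutputs (1,1) {mustBePositive, mustBeInteger}
                args.OutputMode {mustBeMember(args.OutputMode, ["last","sequence"])} = "sequence"
                % Default added so the layer can be created without a Name
                args.Name {mustBeTextScalar} = ""
            end
            this.Name = args.Name;
            % Single LSTM layer wrapped in an uninitialized dlnetwork; it is
            % initialized together with the outer network.
            layer = lstmLayer(numHiddenUnits, OutputMode=args.OutputMode);
            this.Network = dlnetwork(layer, Initialize=false);
            this.NumOutputs = numOutputs;
            this.OutputNames = "out" + (1:numOutputs);
        end
        function varargout = predict(this, X)
            % X is a formatted dlarray (e.g. "CBT"). Channel n is passed through
            % the same nested network and returned as output n.
            varargout = cell(1,this.NumOutputs);
            for n = 1:this.NumOutputs
                varargout{n} = predict(this.Network, X(n,:,:));
            end
        end
    end
end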
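A quick way to check that the outputs really share one set of parameters is to feed the same sequence on every channel and compare the outputs. This sketch assumes the class above is saved as commonLSTMLayer.m on the path, in a release that supports nnet.layer.Acceleratable; the sizes (8 hidden units, 2 observations, 10 time steps) are arbitrary choices for illustration.

```matlab
% Sketch: identical input channels should give identical outputs,
% because every output is computed with the same nested LSTM.
numChannels = 3;
numHiddenUnits = 8;

layers = [
    sequenceInputLayer(numChannels)
    commonLSTMLayer(numHiddenUnits, numChannels, OutputMode="last", Name="lstm")];
% Initializing the outer network also initializes the nested LSTM.
% The unconnected outputs lstm/out1..out3 become the network outputs.
net = dlnetwork(layers);

% Same sequence on all channels: 2 observations, 10 time steps.
seq = rand(1, 2, 10, "single");
X = dlarray(repmat(seq, numChannels, 1, 1), "CBT");

[y1, y2, y3] = predict(net, X);

% Largest difference between the per-channel outputs (expected to be ~0).
maxDiff = max(abs(extractdata([y1 - y2; y1 - y3])), [], "all");
```

If the three outputs used separate LSTM layers with independently initialized weights, maxDiff would generally be clearly nonzero.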
Using this layer we can construct the network as follows:
numInputChannels = 3;
numHiddenUnits = 64;
% Linear part of the graph: lstm/out1 -> fc1 -> cat/in1 -> regression output
layers = [
    sequenceInputLayer(numInputChannels)
    commonLSTMLayer(numHiddenUnits, numInputChannels, OutputMode="last", Name="lstm")
    fullyConnectedLayer(2, Name="fc1")
    concatenationLayer(1, 3, Name="cat")
    regressionLayer()];
lg = layerGraph(layers);
% Branches for the 2nd and 3rd outputs of the shared LSTM layer
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc2"));
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc3"));
lg = connectLayers(lg, "lstm/out2", "fc2");
lg = connectLayers(lg, "lstm/out3", "fc3");
lg = connectLayers(lg, "fc2", "cat/in2");
lg = connectLayers(lg, "fc3", "cat/in3");
analyzeNetwork(lg)
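The layer graph above ends in regressionLayer, which suits trainNetwork. dlnetwork objects do not accept output layers, so if you want a dlnetwork instead (for example for a custom training loop), a minimal sketch is to drop regressionLayer and convert the same graph. The random input (4 observations, 20 time steps) is only there to check sizes: each of the three fully connected branches outputs 2 values, so the concatenated output has 6 channels.

```matlab
% Sketch: same graph without the output layer, as a dlnetwork.
% Assumes commonLSTMLayer.m is on the path.
numInputChannels = 3;
numHiddenUnits = 64;
layers = [
    sequenceInputLayer(numInputChannels)
    commonLSTMLayer(numHiddenUnits, numInputChannels, OutputMode="last", Name="lstm")
    fullyConnectedLayer(2, Name="fc1")
    concatenationLayer(1, 3, Name="cat")];
lg = layerGraph(layers);
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc2"));
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc3"));
lg = connectLayers(lg, "lstm/out2", "fc2");
lg = connectLayers(lg, "lstm/out3", "fc3");
lg = connectLayers(lg, "fc2", "cat/in2");
lg = connectLayers(lg, "fc3", "cat/in3");
net = dlnetwork(lg);

% Forward pass on random data: 3 channels, 4 observations, 20 time steps.
X = dlarray(rand(numInputChannels, 4, 20, "single"), "CBT");
Y = predict(net, X);   % 6-by-4 ("CB"): three branches x 2 outputs each
```

Inspecting net.Learnables should show the LSTM parameters only once, under the single "lstm" layer, rather than three separate copies.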

Release

R2019a
