どのように Stacked Autoencoder でデコード処理までを含ませることができますか?

Stacked Autoencoderを用いて入力画像を再復元したいと考えています。
EncoderとDecoderが、ともに複数の層をもつようなネットワークを作成する方法を、教えてください。

 Accepted Answer

MathWorks Support Team
MathWorks Support Team on 30 Sep 2020
Edited: MathWorks Support Team on 30 Sep 2020
スタックされたネットワークに対して、 Decoder 側を追加するようなスマートな機能は、Neural Network Toolbox (R2017a) では提供されておりません。
Toolbox で、AutoEncoder の 機能が提供される以前に、 AutoEncoder を実現していた例題が下記(※)があり、カスタムネットワークを作成いただくような形で、進めていただくような方法となります。
※ MathWorks Accont によるログインが必要となりますこと、ご了承ください。
上記については、通常の、 AutoEncoder による「識別」を目的とした例題となっておりますので、カスタムネットワークの定義の全体像の紹介までとなりますが、実際に、スタックされたネットワークに対して、 Decoder 側を追加するような方法としては、次のようなスクリプト例が基本構成となります。
X = abalone_dataset; % 8x4177 データを読み込み
% 入力、隠れ層のユニット数
inp_num = size(X,1); % 入力層
hid_num1 = 7; % 隠れ層 1
hid_num2 = 6 ; % 隠れ層 2
% AutoEncoder 実施
auto1 = trainAutoencoder(X,hid_num1,'DecoderTransferFunction','pureline','Scale',false);
tmp1 = encode(auto1,X);
auto2 = trainAutoencoder(tmp1,hid_num2,'DecoderTransferFunction','purelin','Scale',false);
tmp2 = encode(auto2,tmp1);
% 全体のネットワークを作成
fnet = network;
fnet.numInputs = 1;
fnet.numLayers = 4;
fnet.inputConnect(1,1) = 1;
fnet.layerConnect(2,1) = 1;
fnet.layerConnect(3,2) = 1;
fnet.layerConnect(4,3) = 1;
fnet.outputConnect(1,4) = 1;
fnet.biasConnect = [1;1;1;1];
fnet.inputs{1}.size = inp_num;
fnet.layers{1}.size = hid_num1;
fnet.layers{2}.size = hid_num2;
fnet.layers{3}.size = hid_num1;
fnet.layers{4}.size = inp_num;
fnet.layers{1}.transferFcn = 'logsig';
fnet.layers{2}.transferFcn = 'logsig';
fnet.layers{3}.transferFcn = 'purelin';
fnet.layers{4}.transferFcn = 'purelin';
fnet.divideFcn = 'dividetrain';
fnet.IW{1,1} = auto1.EncoderWeights;
fnet.b{1} = auto1.EncoderBiases;
fnet.LW{2,1} = auto2.EncoderWeights;
fnet.b{2} = auto2.EncoderBiases;
fnet.LW{3,2} = auto2.DecoderWeights;
fnet.b{3} = auto2.DecoderBiases;
fnet.LW{4,3} = auto1.DecoderWeights;
fnet.b{4} = auto1.DecoderBiases;
fnet.trainFcn = 'trainscg';
fnet.trainParam.epochs = 1000;
fnet = train(fnet,X,X);
Y = fnet(X);
hold on
for n = 1:inp_num
plot(X(n,:),Y(n,:),'.')
end
grid on
plot([-0.5 3.5],[-0.5 3.5])

More Answers (0)

Categories

Find more on Deep Learning Toolbox in Help Center and File Exchange

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!